diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 51f654fc7613aa..79902c8f5b894e 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -1079,6 +1079,7 @@ def CUDADeviceBuiltinSurfaceType : InheritableAttr { let LangOpts = [CUDA]; let Subjects = SubjectList<[CXXRecord]>; let Documentation = [CUDADeviceBuiltinSurfaceTypeDocs]; + let MeaningfulToClassTemplateDefinition = 1; } def CUDADeviceBuiltinTextureType : InheritableAttr { @@ -1087,6 +1088,7 @@ def CUDADeviceBuiltinTextureType : InheritableAttr { let LangOpts = [CUDA]; let Subjects = SubjectList<[CXXRecord]>; let Documentation = [CUDADeviceBuiltinTextureTypeDocs]; + let MeaningfulToClassTemplateDefinition = 1; } def CUDAGlobal : InheritableAttr { diff --git a/clang/test/SemaCUDA/device-use-host-var.cu b/clang/test/SemaCUDA/device-use-host-var.cu index cf5514610a42ab..c8ef7dbbb18dca 100644 --- a/clang/test/SemaCUDA/device-use-host-var.cu +++ b/clang/test/SemaCUDA/device-use-host-var.cu @@ -158,3 +158,23 @@ void dev_lambda_capture_by_copy(int *out) { }); } +// Texture references are special. As far as C++ is concerned they are host +// variables that are referenced from device code. However, they are handled +// very differently by the compiler under the hood and such references are +// allowed. Compiler should produce no warning here, but it should diagnose the +// same case without the device_builtin_texture_type attribute. +template +struct __attribute__((device_builtin_texture_type)) texture { + static texture ref; + __device__ int c() { + auto &x = ref; + } +}; + +template +struct not_a_texture { + static not_a_texture ref; + __device__ int c() { + auto &x = ref; // dev-error {{reference to __host__ variable 'ref' in __device__ function}} + } +}; diff --git a/flang/include/flang/Common/restorer.h b/flang/include/flang/Common/restorer.h index 47e54237fe43cc..4d5f5e4e2c818d 100644 --- a/flang/include/flang/Common/restorer.h +++ b/flang/include/flang/Common/restorer.h @@ -22,9 +22,16 @@ namespace Fortran::common { template class Restorer { public: - explicit Restorer(A &p) : p_{p}, original_{std::move(p)} {} + explicit Restorer(A &p, A original) : p_{p}, original_{std::move(original)} {} ~Restorer() { p_ = std::move(original_); } + // Inhibit any recreation of this restorer that would result in two restorers + // trying to restore the same reference. + Restorer(const Restorer &) = delete; + Restorer(Restorer &&that) = delete; + const Restorer &operator=(const Restorer &) = delete; + const Restorer &operator=(Restorer &&that) = delete; + private: A &p_; A original_; @@ -32,15 +39,15 @@ template class Restorer { template common::IfNoLvalue, B> ScopedSet(A &to, B &&from) { - Restorer result{to}; + A original{std::move(to)}; to = std::move(from); - return result; + return Restorer{to, std::move(original)}; } template common::IfNoLvalue, B> ScopedSet(A &to, const B &from) { - Restorer result{to}; + A original{std::move(to)}; to = from; - return result; + return Restorer{to, std::move(original)}; } } // namespace Fortran::common #endif // FORTRAN_COMMON_RESTORER_H_ diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 4ffca03958042d..8d7a6d4af95076 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -249,7 +249,7 @@ class fir_AllocatableOp traits = []> : }]; let printer = [{ - p << getOperationName() << ' ' << getAttr(inType()); + p << getOperationName() << ' ' << (*this)->getAttr(inType()); if (hasLenParams()) { // print the LEN parameters to a derived type in parens p << '(' << getLenParams() << " : " << getLenParams().getTypes() << ')'; @@ -267,7 +267,7 @@ class fir_AllocatableOp traits = []> : static constexpr llvm::StringRef lenpName() { return "len_param_count"; } mlir::Type getAllocatedType(); - bool hasLenParams() { return bool{getAttr(lenpName())}; } + bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; } unsigned numLenParams() { if (auto val = (*this)->getAttrOfType(lenpName())) @@ -688,7 +688,7 @@ class fir_IntegralSwitchTerminatorOp(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumConditions(); for (decltype(count) i = 0; i != count; ++i) { if (i) @@ -711,7 +711,7 @@ class fir_IntegralSwitchTerminatorOp() || getSelector().getType().isa())) return emitOpError("must be an integer"); - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumDest(); if (count == 0) return emitOpError("must have at least one successor"); @@ -810,7 +810,7 @@ def fir_SelectCaseOp : fir_SwitchTerminatorOp<"select_case"> { p << getOperationName() << ' '; p.printOperand(getSelector()); p << " : " << getSelector().getType() << " ["; - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumConditions(); for (decltype(count) i = 0; i != count; ++i) { if (i) @@ -839,7 +839,7 @@ def fir_SelectCaseOp : fir_SwitchTerminatorOp<"select_case"> { getSelector().getType().isa() || getSelector().getType().isa())) return emitOpError("must be an integer, character, or logical"); - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumDest(); if (count == 0) return emitOpError("must have at least one successor"); @@ -925,7 +925,7 @@ def fir_SelectTypeOp : fir_SwitchTerminatorOp<"select_type"> { p << getOperationName() << ' '; p.printOperand(getSelector()); p << " : " << getSelector().getType() << " ["; - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumConditions(); for (decltype(count) i = 0; i != count; ++i) { if (i) @@ -941,7 +941,7 @@ def fir_SelectTypeOp : fir_SwitchTerminatorOp<"select_type"> { let verifier = [{ if (!(getSelector().getType().isa())) return emitOpError("must be a boxed type"); - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumDest(); if (count == 0) return emitOpError("must have at least one successor"); @@ -1056,7 +1056,7 @@ def fir_EmboxOp : fir_Op<"embox", [NoSideEffect]> { if (getNumOperands() == 2) { p << ", "; p.printOperands(dims()); - } else if (auto map = getAttr(layoutName())) { + } else if (auto map = (*this)->getAttr(layoutName())) { p << " [" << map << ']'; } p.printOptionalAttrDict(getAttrs(), {layoutName(), lenpName()}); @@ -1097,9 +1097,9 @@ def fir_EmboxOp : fir_Op<"embox", [NoSideEffect]> { let extraClassDeclaration = [{ static constexpr llvm::StringRef layoutName() { return "layout_map"; } static constexpr llvm::StringRef lenpName() { return "len_param_count"; } - bool hasLenParams() { return bool{getAttr(lenpName())}; } + bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; } unsigned numLenParams() { - if (auto x = getAttrOfType(lenpName())) + if (auto x = (*this)->getAttrOfType(lenpName())) return x.getInt(); return 0; } @@ -1213,13 +1213,13 @@ def fir_EmboxProcOp : fir_Op<"emboxproc", [NoSideEffect]> { }]; let printer = [{ - p << getOperationName() << ' ' << getAttr("funcname"); + p << getOperationName() << ' ' << (*this)->getAttr("funcname"); auto h = host(); if (h) { p << ", "; p.printOperand(h); } - p << " : (" << getAttr("functype"); + p << " : (" << (*this)->getAttr("functype"); if (h) p << ", " << h.getType(); p << ") -> " << getType(); @@ -1587,7 +1587,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> { if (!ref().getType().dyn_cast()) return emitOpError("len_param_index must be used on box type"); } - if (auto attr = getAttr(CoordinateOp::baseType())) { + if (auto attr = (*this)->getAttr(CoordinateOp::baseType())) { if (!attr.isa()) return emitOpError("improperly constructed"); } else { @@ -1690,8 +1690,8 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> { let printer = [{ p << getOperationName() << ' ' - << getAttrOfType(fieldAttrName()).getValue() << ", " - << getAttr(typeAttrName()); + << (*this)->getAttrOfType(fieldAttrName()).getValue() + << ", " << (*this)->getAttr(typeAttrName()); if (getNumOperands()) { p << '('; p.printOperands(lenparams()); @@ -1826,8 +1826,8 @@ def fir_LenParamIndexOp : fir_OneResultOp<"len_param_index", [NoSideEffect]> { let printer = [{ p << getOperationName() << ' ' - << getAttrOfType(fieldAttrName()).getValue() << ", " - << getAttr(typeAttrName()); + << (*this)->getAttrOfType(fieldAttrName()).getValue() + << ", " << (*this)->getAttr(typeAttrName()); }]; let builders = [ @@ -1841,7 +1841,7 @@ def fir_LenParamIndexOp : fir_OneResultOp<"len_param_index", [NoSideEffect]> { static constexpr llvm::StringRef fieldAttrName() { return "field_id"; } static constexpr llvm::StringRef typeAttrName() { return "on_type"; } mlir::Type getOnType() { - return getAttrOfType(typeAttrName()).getValue(); + return (*this)->getAttrOfType(typeAttrName()).getValue(); } }]; } @@ -2166,7 +2166,7 @@ def fir_DispatchOp : fir_Op<"dispatch", }]; let printer = [{ - p << getOperationName() << ' ' << getAttr("method") << '('; + p << getOperationName() << ' ' << (*this)->getAttr("method") << '('; p.printOperand(object()); if (arg_operand_begin() != arg_operand_end()) { p << ", "; @@ -2250,7 +2250,7 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> { auto eleTy = getType().cast().getEleTy(); if (!eleTy.isa()) return emitOpError("must have !fir.char type"); - if (auto xl = getAttr(xlist())) { + if (auto xl = (*this)->getAttr(xlist())) { auto xList = xl.cast(); for (auto a : xList) if (!a.isa()) @@ -2265,12 +2265,12 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> { static constexpr const char *xlist() { return "xlist"; } // Get the LEN attribute of this character constant - mlir::Attribute getSize() { return getAttr(size()); } + mlir::Attribute getSize() { return (*this)->getAttr(size()); } // Get the string value of this character constant mlir::Attribute getValue() { - if (auto attr = getAttr(value())) + if (auto attr = (*this)->getAttr(value())) return attr; - return getAttr(xlist()); + return (*this)->getAttr(xlist()); } /// Is this a wide character literal (1 character > 8 bits) @@ -2381,7 +2381,7 @@ def fir_CmpfOp : fir_Op<"cmpf", static CmpFPredicate getPredicateByName(llvm::StringRef name); CmpFPredicate getPredicate() { - return (CmpFPredicate)getAttrOfType( + return (CmpFPredicate)(*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -2415,11 +2415,11 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { let printer = [{ p << getOperationName() << " (0x"; - auto f1 = getAttr(realAttrName()).cast(); + auto f1 = (*this)->getAttr(realAttrName()).cast(); auto i1 = f1.getValue().bitcastToAPInt(); p.getStream().write_hex(i1.getZExtValue()); p << ", 0x"; - auto f2 = getAttr(imagAttrName()).cast(); + auto f2 = (*this)->getAttr(imagAttrName()).cast(); auto i2 = f2.getValue().bitcastToAPInt(); p.getStream().write_hex(i2.getZExtValue()); p << ") : "; @@ -2436,8 +2436,8 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { static constexpr llvm::StringRef realAttrName() { return "real"; } static constexpr llvm::StringRef imagAttrName() { return "imaginary"; } - mlir::Attribute getReal() { return getAttr(realAttrName()); } - mlir::Attribute getImaginary() { return getAttr(imagAttrName()); } + mlir::Attribute getReal() { return (*this)->getAttr(realAttrName()); } + mlir::Attribute getImaginary() { return (*this)->getAttr(imagAttrName()); } }]; } @@ -2485,7 +2485,7 @@ def fir_CmpcOp : fir_Op<"cmpc", } CmpFPredicate getPredicate() { - return (CmpFPredicate)getAttrOfType( + return (CmpFPredicate)(*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -2601,7 +2601,7 @@ def fir_GenTypeDescOp : fir_OneResultOp<"gentypedesc", [NoSideEffect]> { }]; let printer = [{ - p << getOperationName() << ' ' << getAttr("in_type"); + p << getOperationName() << ' ' << (*this)->getAttr("in_type"); p.printOptionalAttrDict(getAttrs(), {"in_type"}); }]; @@ -2623,7 +2623,7 @@ def fir_GenTypeDescOp : fir_OneResultOp<"gentypedesc", [NoSideEffect]> { let extraClassDeclaration = [{ mlir::Type getInType() { // get the type that the type descriptor describes - return getAttrOfType("in_type").getValue(); + return (*this)->getAttrOfType("in_type").getValue(); } }]; } @@ -2697,7 +2697,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { if (linkName().hasValue()) p << ' ' << linkName().getValue(); p << ' '; - p.printAttributeWithoutType(getAttr(symbolAttrName())); + p.printAttributeWithoutType((*this)->getAttr(symbolAttrName())); if (auto val = getValueOrNull()) p << '(' << val << ')'; if ((*this)->getAttr(constantAttrName())) @@ -2738,7 +2738,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { /// The printable type of the global mlir::Type getType() { - return getAttrOfType(typeAttrName()).getValue(); + return (*this)->getAttrOfType(typeAttrName()).getValue(); } /// The semantic type of the global @@ -2768,8 +2768,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { } mlir::FlatSymbolRefAttr getSymbol() { - return mlir::FlatSymbolRefAttr::get(getAttrOfType( - mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext()); + return mlir::FlatSymbolRefAttr::get( + (*this)->getAttrOfType( + mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext()); } }]; } @@ -2811,8 +2812,8 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> { }]; let printer = [{ - p << getOperationName() << ' ' << getAttr(lenParamAttrName()) << ", " - << getAttr(intAttrName()); + p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) + << ", " << (*this)->getAttr(intAttrName()); }]; let extraClassDeclaration = [{ @@ -2865,7 +2866,7 @@ def fir_DispatchTableOp : fir_Op<"dispatch_table", }]; let printer = [{ - auto tableName = getAttrOfType( + auto tableName = (*this)->getAttrOfType( mlir::SymbolTable::getSymbolAttrName()).getValue(); p << getOperationName() << " @" << tableName; @@ -2946,8 +2947,8 @@ def fir_DTEntryOp : fir_Op<"dt_entry", []> { }]; let printer = [{ - p << getOperationName() << ' ' << getAttr(methodAttrName()) << ", " - << getAttr(procAttrName()); + p << getOperationName() << ' ' << (*this)->getAttr(methodAttrName()) << ", " + << (*this)->getAttr(procAttrName()); }]; let extraClassDeclaration = [{ diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h index 43588ff17962ed..2d9ad28981ab24 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h +++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h @@ -18,7 +18,7 @@ namespace fir { /// return true iff the Operation is a non-volatile LoadOp inline bool nonVolatileLoad(mlir::Operation *op) { if (auto load = dyn_cast(op)) - return !load.getAttr("volatile"); + return !load->getAttr("volatile"); return false; } diff --git a/flang/lib/Lower/CharacterRuntime.cpp b/flang/lib/Lower/CharacterRuntime.cpp index af95885f985d2a..4bfbf5824efbbe 100644 --- a/flang/lib/Lower/CharacterRuntime.cpp +++ b/flang/lib/Lower/CharacterRuntime.cpp @@ -62,7 +62,7 @@ static mlir::FuncOp getRuntimeFunc(mlir::Location loc, return func; auto funTy = getTypeModel()(builder.getContext()); func = builder.createFunction(loc, name, funTy); - func.setAttr("fir.runtime", builder.getUnitAttr()); + func->setAttr("fir.runtime", builder.getUnitAttr()); return func; } diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp index 3f79b79e32ee9a..ab7387dd3ce6e1 100644 --- a/flang/lib/Lower/IO.cpp +++ b/flang/lib/Lower/IO.cpp @@ -123,8 +123,8 @@ static mlir::FuncOp getIORuntimeFunc(mlir::Location loc, return func; auto funTy = getTypeModel()(builder.getContext()); func = builder.createFunction(loc, name, funTy); - func.setAttr("fir.runtime", builder.getUnitAttr()); - func.setAttr("fir.io", builder.getUnitAttr()); + func->setAttr("fir.runtime", builder.getUnitAttr()); + func->setAttr("fir.io", builder.getUnitAttr()); return func; } diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp index 23b084eaf67d5f..0e0081ef664cfe 100644 --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -497,7 +497,7 @@ static mlir::FuncOp getFuncOp(mlir::Location loc, const RuntimeFunction &runtime) { auto function = builder.addNamedFunction( loc, runtime.symbol, runtime.typeGenerator(builder.getContext())); - function.setAttr("fir.runtime", builder.getUnitAttr()); + function->setAttr("fir.runtime", builder.getUnitAttr()); return function; } @@ -769,7 +769,7 @@ mlir::FuncOp IntrinsicLibrary::getWrapper(GeneratorType generator, if (!function) { // First time this wrapper is needed, build it. function = builder.createFunction(loc, wrapperName, funcType); - function.setAttr("fir.intrinsic", builder.getUnitAttr()); + function->setAttr("fir.intrinsic", builder.getUnitAttr()); function.addEntryBlock(); // Create local context to emit code into the newly created function diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 6f45bb623d7db9..12cf97869543aa 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -98,8 +98,8 @@ static Op createRegionOp(Fortran::lower::FirOpBuilder &builder, builder.setInsertionPointToStart(&block); builder.create(loc); - op.setAttr(Op::getOperandSegmentSizeAttr(), - builder.getI32VectorAttr(operandSegments)); + op->setAttr(Op::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr(operandSegments)); // Place the insertion point to the start of the first block. builder.setInsertionPointToStart(&block); @@ -114,8 +114,8 @@ static Op createSimpleOp(Fortran::lower::FirOpBuilder &builder, const SmallVectorImpl &operandSegments) { llvm::ArrayRef argTy; Op op = builder.create(loc, argTy, operands); - op.setAttr(Op::getOperandSegmentSizeAttr(), - builder.getI32VectorAttr(operandSegments)); + op->setAttr(Op::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr(operandSegments)); return op; } @@ -231,8 +231,8 @@ static void genACC(Fortran::lower::AbstractConverter &converter, auto loopOp = createRegionOp( firOpBuilder, currentLocation, operands, operandSegments); - loopOp.setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), - firOpBuilder.getI64IntegerAttr(executionMapping)); + loopOp->setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), + firOpBuilder.getI64IntegerAttr(executionMapping)); // Lower clauses mapped to attributes for (const auto &clause : accClauseList.v) { @@ -241,19 +241,19 @@ static void genACC(Fortran::lower::AbstractConverter &converter, const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); const auto collapseValue = Fortran::evaluate::ToInt64(*expr); if (collapseValue) { - loopOp.setAttr(mlir::acc::LoopOp::getCollapseAttrName(), - firOpBuilder.getI64IntegerAttr(*collapseValue)); + loopOp->setAttr(mlir::acc::LoopOp::getCollapseAttrName(), + firOpBuilder.getI64IntegerAttr(*collapseValue)); } } else if (std::get_if(&clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getSeqAttrName(), - firOpBuilder.getUnitAttr()); + loopOp->setAttr(mlir::acc::LoopOp::getSeqAttrName(), + firOpBuilder.getUnitAttr()); } else if (std::get_if( &clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getIndependentAttrName(), - firOpBuilder.getUnitAttr()); + loopOp->setAttr(mlir::acc::LoopOp::getIndependentAttrName(), + firOpBuilder.getUnitAttr()); } else if (std::get_if(&clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getAutoAttrName(), - firOpBuilder.getUnitAttr()); + loopOp->setAttr(mlir::acc::LoopOp::getAutoAttrName(), + firOpBuilder.getUnitAttr()); } } } @@ -425,14 +425,14 @@ genACCParallelOp(Fortran::lower::AbstractConverter &converter, firOpBuilder, currentLocation, operands, operandSegments); if (addAsyncAttr) - parallelOp.setAttr(mlir::acc::ParallelOp::getAsyncAttrName(), - firOpBuilder.getUnitAttr()); + parallelOp->setAttr(mlir::acc::ParallelOp::getAsyncAttrName(), + firOpBuilder.getUnitAttr()); if (addWaitAttr) - parallelOp.setAttr(mlir::acc::ParallelOp::getWaitAttrName(), - firOpBuilder.getUnitAttr()); + parallelOp->setAttr(mlir::acc::ParallelOp::getWaitAttrName(), + firOpBuilder.getUnitAttr()); if (addSelfAttr) - parallelOp.setAttr(mlir::acc::ParallelOp::getSelfAttrName(), - firOpBuilder.getUnitAttr()); + parallelOp->setAttr(mlir::acc::ParallelOp::getSelfAttrName(), + firOpBuilder.getUnitAttr()); } static void genACCDataOp(Fortran::lower::AbstractConverter &converter, diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 32d04b79c058f7..4a6c8d50e2ae04 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -166,7 +166,7 @@ static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { p << callee.getValue(); else p << op.getOperand(0); - p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; + p << '(' << op->getOperands().drop_front(isDirect ? 0 : 1) << ')'; p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); auto resultTypes{op.getResultTypes()}; llvm::SmallVector argTypes( @@ -240,7 +240,8 @@ template static void printCmpOp(OpAsmPrinter &p, OPTY op) { p << op.getOperationName() << ' '; auto predSym = mlir::symbolizeCmpFPredicate( - op.template getAttrOfType(OPTY::getPredicateAttrName()) + op->template getAttrOfType( + OPTY::getPredicateAttrName()) .getInt()); assert(predSym.hasValue() && "invalid symbol value for predicate"); p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; @@ -385,7 +386,10 @@ static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, } mlir::Type fir::CoordinateOp::getBaseType() { - return getAttr(CoordinateOp::baseType()).cast().getValue(); + return (*this) + ->getAttr(CoordinateOp::baseType()) + .cast() + .getValue(); } void fir::CoordinateOp::build(OpBuilder &, OperationState &result, @@ -412,7 +416,7 @@ void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, //===----------------------------------------------------------------------===// mlir::FunctionType fir::DispatchOp::getFunctionType() { - auto attr = getAttr("fn_type").cast(); + auto attr = (*this)->getAttr("fn_type").cast(); return attr.getValue().cast(); } @@ -745,7 +749,7 @@ static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> (" << op.getResultTypes().drop_front() << ')'; } - p.printOptionalAttrDictWithKeyword(op.getAttrs(), {}); + p.printOptionalAttrDictWithKeyword(op->getAttrs(), {}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -930,7 +934,7 @@ static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { p << ") -> (" << op.getResultTypes() << ')'; printBlockTerminators = true; } - p.printOptionalAttrDictWithKeyword(op.getAttrs(), + p.printOptionalAttrDictWithKeyword(op->getAttrs(), {fir::LoopOp::unorderedAttrName()}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, printBlockTerminators); @@ -963,9 +967,9 @@ mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef opnds) { //===----------------------------------------------------------------------===// static mlir::LogicalResult verify(fir::ResultOp op) { - auto parentOp = op.getParentOp(); + auto *parentOp = op->getParentOp(); auto results = parentOp->getResults(); - auto operands = op.getOperands(); + auto operands = op->getOperands(); if (parentOp->getNumResults() != op.getNumOperands()) return op.emitOpError() << "parent of result must have same arity"; @@ -1032,15 +1036,16 @@ fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { llvm::Optional> fir::SelectOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } unsigned fir::SelectOp::targetOffsetSize() { - return denseElementsSize( - getAttrOfType(getTargetOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getTargetOffsetAttr())); } //===----------------------------------------------------------------------===// @@ -1049,16 +1054,18 @@ unsigned fir::SelectOp::targetOffsetSize() { llvm::Optional fir::SelectCaseOp::getCompareOperands(unsigned cond) { - auto a = getAttrOfType(getCompareOffsetAttr()); + auto a = (*this)->getAttrOfType( + getCompareOffsetAttr()); return {getSubOperands(cond, compareArgs(), a)}; } llvm::Optional> fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef operands, unsigned cond) { - auto a = getAttrOfType(getCompareOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = (*this)->getAttrOfType( + getCompareOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } @@ -1071,9 +1078,10 @@ fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { llvm::Optional> fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } @@ -1152,13 +1160,13 @@ static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, } unsigned fir::SelectCaseOp::compareOffsetSize() { - return denseElementsSize( - getAttrOfType(getCompareOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getCompareOffsetAttr())); } unsigned fir::SelectCaseOp::targetOffsetSize() { - return denseElementsSize( - getAttrOfType(getTargetOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getTargetOffsetAttr())); } void fir::SelectCaseOp::build(mlir::OpBuilder &builder, @@ -1262,15 +1270,16 @@ fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { llvm::Optional> fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } unsigned fir::SelectRankOp::targetOffsetSize() { - return denseElementsSize( - getAttrOfType(getTargetOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getTargetOffsetAttr())); } //===----------------------------------------------------------------------===// @@ -1296,9 +1305,10 @@ fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { llvm::Optional> fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } @@ -1348,8 +1358,8 @@ static ParseResult parseSelectType(OpAsmParser &parser, } unsigned fir::SelectTypeOp::targetOffsetSize() { - return denseElementsSize( - getAttrOfType(getTargetOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getTargetOffsetAttr())); } //===----------------------------------------------------------------------===// @@ -1467,7 +1477,7 @@ static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { p.printRegion(otherReg, /*printEntryBlockArgs=*/false, printBlockTerminators); } - p.printOptionalAttrDict(op.getAttrs()); + p.printOptionalAttrDict(op->getAttrs()); } //===----------------------------------------------------------------------===// diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index 0d2e2e86241c66..dd76f67100701e 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -1498,7 +1498,7 @@ void CheckHelper::CheckProcBinding( void CheckHelper::Check(const Scope &scope) { scope_ = &scope; - common::Restorer restorer{innermostSymbol_}; + common::Restorer restorer{innermostSymbol_, innermostSymbol_}; if (const Symbol * symbol{scope.symbol()}) { innermostSymbol_ = symbol; } diff --git a/llvm/include/llvm/Support/GenericDomTree.h b/llvm/include/llvm/Support/GenericDomTree.h index 4bed550f44c0f2..d2d7c8c4481d6d 100644 --- a/llvm/include/llvm/Support/GenericDomTree.h +++ b/llvm/include/llvm/Support/GenericDomTree.h @@ -554,7 +554,6 @@ class DominatorTreeBase { /// obtained PostViewCFG is the desired end state. void applyUpdates(ArrayRef Updates, ArrayRef PostViewUpdates) { - // GraphDiff *PostViewCFG = nullptr) { if (Updates.empty()) { GraphDiff PostViewCFG(PostViewUpdates); DomTreeBuilder::ApplyUpdates(*this, PostViewCFG, &PostViewCFG); diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 65d39161c1be48..be340a3b3130da 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -1781,26 +1781,6 @@ bool LoopAccessInfo::canAnalyzeLoop() { return false; } - // We must have a single exiting block. - if (!TheLoop->getExitingBlock()) { - LLVM_DEBUG( - dbgs() << "LAA: loop control flow is not understood by analyzer\n"); - recordAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by analyzer"; - return false; - } - - // We only handle bottom-tested loops, i.e. loop in which the condition is - // checked at the end of each iteration. With that we can assume that all - // instructions in the loop are executed the same number of times. - if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { - LLVM_DEBUG( - dbgs() << "LAA: loop control flow is not understood by analyzer\n"); - recordAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by analyzer"; - return false; - } - // ScalarEvolution needs to be able to find the exit count. const SCEV *ExitCount = PSE->getBackedgeTakenCount(); if (isa(ExitCount)) { diff --git a/llvm/lib/CodeGen/MachineBlockPlacement.cpp b/llvm/lib/CodeGen/MachineBlockPlacement.cpp index bd4640822a6315..42586cbe06e097 100644 --- a/llvm/lib/CodeGen/MachineBlockPlacement.cpp +++ b/llvm/lib/CodeGen/MachineBlockPlacement.cpp @@ -2306,6 +2306,10 @@ void MachineBlockPlacement::rotateLoop(BlockChain &LoopChain, if (Bottom == ExitingBB) return; + // The entry block should always be the first BB in a function. + if (Top->isEntryBlock()) + return; + bool ViableTopFallthrough = hasViableTopFallthrough(Top, LoopBlockSet); // If the header has viable fallthrough, check whether the current loop @@ -2380,6 +2384,11 @@ void MachineBlockPlacement::rotateLoopWithProfile( BlockChain &LoopChain, const MachineLoop &L, const BlockFilterSet &LoopBlockSet) { auto RotationPos = LoopChain.end(); + MachineBasicBlock *ChainHeaderBB = *LoopChain.begin(); + + // The entry block should always be the first BB in a function. + if (ChainHeaderBB->isEntryBlock()) + return; BlockFrequency SmallestRotationCost = BlockFrequency::getMaxFrequency(); @@ -2398,7 +2407,6 @@ void MachineBlockPlacement::rotateLoopWithProfile( // chain head is not the loop header. As we only consider natural loops with // single header, this computation can be done only once. BlockFrequency HeaderFallThroughCost(0); - MachineBasicBlock *ChainHeaderBB = *LoopChain.begin(); for (auto *Pred : ChainHeaderBB->predecessors()) { BlockChain *PredChain = BlockToChain[Pred]; if (!LoopBlockSet.count(Pred) && diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp index 4c0c202be4bef9..2d90e37349e57b 100644 --- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp +++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp @@ -292,7 +292,7 @@ template <> StoreInst *isCandidate(Instruction *In) { return getIfUnordered(dyn_cast(In)); } -#if !defined(_MSC_VER) || _MSC_VER >= 1920 +#if !defined(_MSC_VER) || _MSC_VER >= 1924 // VS2017 has trouble compiling this: // error C2976: 'std::map': too few template arguments template diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 98d67efef922ea..3dd7d9dce67ac0 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -670,15 +670,17 @@ class LoopDistributeForLoop { << L->getHeader()->getParent()->getName() << "\" checking " << *L << "\n"); + // Having a single exit block implies there's also one exiting block. if (!L->getExitBlock()) return fail("MultipleExitBlocks", "multiple exit blocks"); if (!L->isLoopSimplifyForm()) return fail("NotLoopSimplifyForm", "loop is not in loop-simplify form"); + if (!L->isRotatedForm()) + return fail("NotBottomTested", "loop is not bottom tested"); BasicBlock *PH = L->getLoopPreheader(); - // LAA will check that we only have a single exiting block. LAI = &GetLAA(*L); // Currently, we only distribute to isolate the part of the loop with diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 475448740ae471..56afddead6191a 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -632,6 +632,9 @@ eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, // Now walk the identified inner loops. for (Loop *L : Worklist) { + // Match historical behavior + if (!L->isRotatedForm() || !L->getExitingBlock()) + continue; // The actual work is performed by LoadEliminationForLoop. LoadEliminationForLoop LEL(L, &LI, GetLAI(*L), &DT, BFI, PSI); Changed |= LEL.processLoop(); diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index 03eb41b5ee0d3d..b605cb2fb865bf 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -269,8 +269,11 @@ bool runImpl(LoopInfo *LI, function_ref GetLAA, // Now walk the identified inner loops. bool Changed = false; for (Loop *L : Worklist) { + if (!L->isLoopSimplifyForm() || !L->isRotatedForm() || + !L->getExitingBlock()) + continue; const LoopAccessInfo &LAI = GetLAA(*L); - if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() && + if (!LAI.hasConvergentOp() && (LAI.getNumRuntimePointerChecks() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L, diff --git a/llvm/test/Transforms/VectorCombine/X86/load.ll b/llvm/test/Transforms/VectorCombine/X86/load.ll index 824a507ed1036b..ba2bf3f37d7b67 100644 --- a/llvm/test/Transforms/VectorCombine/X86/load.ll +++ b/llvm/test/Transforms/VectorCombine/X86/load.ll @@ -535,3 +535,20 @@ define <8 x i32> @load_v1i32_extract_insert_v8i32_extra_use(<1 x i32>* align 16 %r = insertelement <8 x i32> undef, i32 %s, i32 0 ret <8 x i32> %r } + +; TODO: Can't safely load the offset vector, but can load+shuffle if it is profitable. + +define <8 x i16> @gep1_load_v2i16_extract_insert_v8i16(<2 x i16>* align 16 dereferenceable(16) %p) { +; CHECK-LABEL: @gep1_load_v2i16_extract_insert_v8i16( +; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds <2 x i16>, <2 x i16>* [[P:%.*]], i64 1 +; CHECK-NEXT: [[L:%.*]] = load <2 x i16>, <2 x i16>* [[GEP]], align 2 +; CHECK-NEXT: [[S:%.*]] = extractelement <2 x i16> [[L]], i32 0 +; CHECK-NEXT: [[R:%.*]] = insertelement <8 x i16> undef, i16 [[S]], i64 0 +; CHECK-NEXT: ret <8 x i16> [[R]] +; + %gep = getelementptr inbounds <2 x i16>, <2 x i16>* %p, i64 1 + %l = load <2 x i16>, <2 x i16>* %gep, align 2 + %s = extractelement <2 x i16> %l, i32 0 + %r = insertelement <8 x i16> undef, i16 %s, i64 0 + ret <8 x i16> %r +} diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md index a267a60adc3ef2..189cd0825af7eb 100644 --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -646,6 +646,30 @@ The following are the set of valid punctuation: `:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`, `?`, `+`, `*` +The following are valid whitespace punctuation: + +`\n`, ` ` + +The `\n` literal emits a newline an indents to the start of the operation. An +example is shown below: + +```tablegen +let assemblyFormat = [{ + `{` `\n` ` ` ` ` `this_is_on_a_newline` `\n` `}` attr-dict +}]; +``` + +```mlir +%results = my.operation { + this_is_on_a_newline +} +``` + +An empty literal \`\` may be used to remove a space that is inserted implicitly +after certain literal elements, such as `)`/`]`/etc. For example, "`]`" may +result in an output of `]` it is not the last element in the format. "`]` \`\`" +would trim the trailing space in this situation. + #### Variables A variable is an entity that has been registered on the operation itself, i.e. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index b200f7c2dc6cf2..2edc48dd099e6e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -108,27 +108,6 @@ class OpState { /// Return the operation that this refers to. Operation *getOperation() { return state; } - /// Return the dialect that this refers to. - Dialect *getDialect() { return getOperation()->getDialect(); } - - /// Return the parent Region of this operation. - Region *getParentRegion() { return getOperation()->getParentRegion(); } - - /// Returns the closest surrounding operation that contains this operation - /// or nullptr if this is a top-level operation. - Operation *getParentOp() { return getOperation()->getParentOp(); } - - /// Return the closest surrounding parent operation that is of type 'OpTy'. - template OpTy getParentOfType() { - return getOperation()->getParentOfType(); - } - - /// Returns the closest surrounding parent operation with trait `Trait`. - template