diff --git a/llvm/lib/Analysis/StackSafetyAnalysis.cpp b/llvm/lib/Analysis/StackSafetyAnalysis.cpp index f2c2bfd3b1a9c3..853707e9c23ddc 100644 --- a/llvm/lib/Analysis/StackSafetyAnalysis.cpp +++ b/llvm/lib/Analysis/StackSafetyAnalysis.cpp @@ -341,6 +341,13 @@ bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr, return false; } + unsigned ArgNo = CB.getArgOperandNo(&UI); + if (CB.isByValArgument(ArgNo)) { + US.updateRange(getAccessRange( + UI, Ptr, DL.getTypeStoreSize(CB.getParamByValType(ArgNo)))); + break; + } + // FIXME: consult devirt? // Do not follow aliases, otherwise we could inadvertently follow // dso_preemptable aliases or aliases with interposable linkage. @@ -352,8 +359,7 @@ bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr, } assert(isa(Callee) || isa(Callee)); - US.Calls.emplace_back(Callee, CB.getArgOperandNo(&UI), - offsetFrom(UI, Ptr)); + US.Calls.emplace_back(Callee, ArgNo, offsetFrom(UI, Ptr)); break; } @@ -382,7 +388,9 @@ FunctionInfo StackSafetyLocalAnalysis::run() { } for (Argument &A : make_range(F.arg_begin(), F.arg_end())) { - if (A.getType()->isPointerTy()) { + // Non pointers and bypass arguments are not going to be used in any global + // processing. + if (A.getType()->isPointerTy() && !A.hasByValAttr()) { auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second; analyzeAllUses(&A, UI); } diff --git a/llvm/lib/Support/Host.cpp b/llvm/lib/Support/Host.cpp index 930abc3ea1bac4..be46a53ea73a88 100644 --- a/llvm/lib/Support/Host.cpp +++ b/llvm/lib/Support/Host.cpp @@ -585,6 +585,16 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, unsigned Brand_id, unsigned Features, unsigned Features2, unsigned Features3, unsigned *Type, unsigned *Subtype) { + auto testFeature = [&](unsigned F) { + if (F < 32) + return (Features & (1U << (F & 0x1f))) != 0; + if (F < 64) + return (Features2 & (1U << ((F - 32) & 0x1f))) != 0; + if (F < 96) + return (Features3 & (1U << ((F - 64) & 0x1f))) != 0; + llvm_unreachable("Unexpected FeatureBit"); + }; + if (Brand_id != 0) return; switch (Family) { @@ -595,7 +605,7 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, *Type = X86::INTEL_i486; break; case 5: - if (Features & (1 << X86::FEATURE_MMX)) { + if (testFeature(X86::FEATURE_MMX)) { *Type = X86::INTEL_PENTIUM_MMX; break; } @@ -711,9 +721,9 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, // Skylake Xeon: case 0x55: *Type = X86::INTEL_COREI7; - if (Features2 & (1 << (X86::FEATURE_AVX512BF16 - 32))) + if (testFeature(X86::FEATURE_AVX512BF16)) *Subtype = X86::INTEL_COREI7_COOPERLAKE; // "cooperlake" - else if (Features2 & (1 << (X86::FEATURE_AVX512VNNI - 32))) + else if (testFeature(X86::FEATURE_AVX512VNNI)) *Subtype = X86::INTEL_COREI7_CASCADELAKE; // "cascadelake" else *Subtype = X86::INTEL_COREI7_SKYLAKE_AVX512; // "skylake-avx512" @@ -777,50 +787,50 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, break; default: // Unknown family 6 CPU, try to guess. - // TODO detect tigerlake host - if (Features2 & (1 << (X86::FEATURE_AVX512VP2INTERSECT - 32))) { + // TODO detect tigerlake host from model + if (testFeature(X86::FEATURE_AVX512VP2INTERSECT)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_TIGERLAKE; break; } - if (Features & (1 << X86::FEATURE_AVX512VBMI2)) { + if (testFeature(X86::FEATURE_AVX512VBMI2)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_ICELAKE_CLIENT; break; } - if (Features & (1 << X86::FEATURE_AVX512VBMI)) { + if (testFeature(X86::FEATURE_AVX512VBMI)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_CANNONLAKE; break; } - if (Features2 & (1 << (X86::FEATURE_AVX512BF16 - 32))) { + if (testFeature(X86::FEATURE_AVX512BF16)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_COOPERLAKE; break; } - if (Features2 & (1 << (X86::FEATURE_AVX512VNNI - 32))) { + if (testFeature(X86::FEATURE_AVX512VNNI)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_CASCADELAKE; break; } - if (Features & (1 << X86::FEATURE_AVX512VL)) { + if (testFeature(X86::FEATURE_AVX512VL)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_SKYLAKE_AVX512; break; } - if (Features & (1 << X86::FEATURE_AVX512ER)) { + if (testFeature(X86::FEATURE_AVX512ER)) { *Type = X86::INTEL_KNL; // knl break; } - if (Features3 & (1 << (X86::FEATURE_CLFLUSHOPT - 64))) { - if (Features3 & (1 << (X86::FEATURE_SHA - 64))) { + if (testFeature(X86::FEATURE_CLFLUSHOPT)) { + if (testFeature(X86::FEATURE_SHA)) { *Type = X86::INTEL_GOLDMONT; } else { *Type = X86::INTEL_COREI7; @@ -828,23 +838,23 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, } break; } - if (Features3 & (1 << (X86::FEATURE_ADX - 64))) { + if (testFeature(X86::FEATURE_ADX)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_BROADWELL; break; } - if (Features & (1 << X86::FEATURE_AVX2)) { + if (testFeature(X86::FEATURE_AVX2)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_HASWELL; break; } - if (Features & (1 << X86::FEATURE_AVX)) { + if (testFeature(X86::FEATURE_AVX)) { *Type = X86::INTEL_COREI7; *Subtype = X86::INTEL_COREI7_SANDYBRIDGE; break; } - if (Features & (1 << X86::FEATURE_SSE4_2)) { - if (Features3 & (1 << (X86::FEATURE_MOVBE - 64))) { + if (testFeature(X86::FEATURE_SSE4_2)) { + if (testFeature(X86::FEATURE_MOVBE)) { *Type = X86::INTEL_SILVERMONT; } else { *Type = X86::INTEL_COREI7; @@ -852,13 +862,13 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, } break; } - if (Features & (1 << X86::FEATURE_SSE4_1)) { + if (testFeature(X86::FEATURE_SSE4_1)) { *Type = X86::INTEL_CORE2; // "penryn" *Subtype = X86::INTEL_CORE2_45; break; } - if (Features & (1 << X86::FEATURE_SSSE3)) { - if (Features3 & (1 << (X86::FEATURE_MOVBE - 64))) { + if (testFeature(X86::FEATURE_SSSE3)) { + if (testFeature(X86::FEATURE_MOVBE)) { *Type = X86::INTEL_BONNELL; // "bonnell" } else { *Type = X86::INTEL_CORE2; // "core2" @@ -866,24 +876,24 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, } break; } - if (Features3 & (1 << (X86::FEATURE_EM64T - 64))) { + if (testFeature(X86::FEATURE_EM64T)) { *Type = X86::INTEL_CORE2; // "core2" *Subtype = X86::INTEL_CORE2_65; break; } - if (Features & (1 << X86::FEATURE_SSE3)) { + if (testFeature(X86::FEATURE_SSE3)) { *Type = X86::INTEL_CORE_DUO; break; } - if (Features & (1 << X86::FEATURE_SSE2)) { + if (testFeature(X86::FEATURE_SSE2)) { *Type = X86::INTEL_PENTIUM_M; break; } - if (Features & (1 << X86::FEATURE_SSE)) { + if (testFeature(X86::FEATURE_SSE)) { *Type = X86::INTEL_PENTIUM_III; break; } - if (Features & (1 << X86::FEATURE_MMX)) { + if (testFeature(X86::FEATURE_MMX)) { *Type = X86::INTEL_PENTIUM_II; break; } @@ -892,11 +902,11 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, } break; case 15: { - if (Features3 & (1 << (X86::FEATURE_EM64T - 64))) { + if (testFeature(X86::FEATURE_EM64T)) { *Type = X86::INTEL_NOCONA; break; } - if (Features & (1 << X86::FEATURE_SSE3)) { + if (testFeature(X86::FEATURE_SSE3)) { *Type = X86::INTEL_PRESCOTT; break; } @@ -911,6 +921,12 @@ getIntelProcessorTypeAndSubtype(unsigned Family, unsigned Model, static void getAMDProcessorTypeAndSubtype(unsigned Family, unsigned Model, unsigned Features, unsigned *Type, unsigned *Subtype) { + auto testFeature = [&](unsigned F) { + if (F < 32) + return (Features & (1U << (F & 0x1f))) != 0; + llvm_unreachable("Unexpected FeatureBit"); + }; + // FIXME: this poorly matches the generated SubtargetFeatureKV table. There // appears to be no way to generate the wide variety of AMD-specific targets // from the information returned from CPUID. @@ -938,14 +954,14 @@ static void getAMDProcessorTypeAndSubtype(unsigned Family, unsigned Model, } break; case 6: - if (Features & (1 << X86::FEATURE_SSE)) { + if (testFeature(X86::FEATURE_SSE)) { *Type = X86::AMD_ATHLON_XP; break; // "athlon-xp" } *Type = X86::AMD_ATHLON; break; // "athlon" case 15: - if (Features & (1 << X86::FEATURE_SSE3)) { + if (testFeature(X86::FEATURE_SSE3)) { *Type = X86::AMD_K8SSE3; break; // "k8-sse3" } diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index d2d71dde4fd571..9d0500419a7f57 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -214,11 +214,16 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { if (!CondBr) return; - BranchProbability BP; uint64_t TrueWeight, FalseWeight; if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight)) return; + if (TrueWeight + FalseWeight == 0) + // Zero branch_weights do not give a hint for getting branch probabilities. + // Technically it would result in division by zero denominator, which is + // TrueWeight + FalseWeight. + return; + // Returns the outgoing edge of the dominating predecessor block // that leads to the PhiNode's incoming block: auto GetPredOutEdge = @@ -253,10 +258,11 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { if (!CI || !CI->getType()->isIntegerTy(1)) continue; - BP = (CI->isOne() ? BranchProbability::getBranchProbability( - TrueWeight, TrueWeight + FalseWeight) - : BranchProbability::getBranchProbability( - FalseWeight, TrueWeight + FalseWeight)); + BranchProbability BP = + (CI->isOne() ? BranchProbability::getBranchProbability( + TrueWeight, TrueWeight + FalseWeight) + : BranchProbability::getBranchProbability( + FalseWeight, TrueWeight + FalseWeight)); auto PredOutEdge = GetPredOutEdge(PN->getIncomingBlock(i), BB); if (!PredOutEdge.first) diff --git a/llvm/test/Analysis/StackSafetyAnalysis/local.ll b/llvm/test/Analysis/StackSafetyAnalysis/local.ll index 1b067fad566bd8..75cf64fb177701 100644 --- a/llvm/test/Analysis/StackSafetyAnalysis/local.ll +++ b/llvm/test/Analysis/StackSafetyAnalysis/local.ll @@ -416,3 +416,47 @@ entry: call void @LeakAddress() ["unknown"(i32* %a)] ret void } + +define void @ByVal(i16* byval %p) { + ; CHECK-LABEL: @ByVal dso_preemptable{{$}} + ; CHECK-NEXT: args uses: + ; CHECK-NEXT: allocas uses: + ; CHECK-NOT: ]: +entry: + ret void +} + +define void @TestByVal() { +; CHECK-LABEL: @TestByVal dso_preemptable{{$}} +; CHECK-NEXT: args uses: +; CHECK-NEXT: allocas uses: +; CHECK-NEXT: x[2]: [0,2) +; CHECK-NEXT: y[8]: [0,2) +; CHECK-NOT: ]: +entry: + %x = alloca i16, align 4 + call void @ByVal(i16* byval %x) + + %y = alloca i64, align 4 + %y1 = bitcast i64* %y to i16* + call void @ByVal(i16* byval %y1) + + ret void +} + +declare void @ByValArray([100000 x i64]* byval %p) + +define void @TestByValArray() { +; CHECK-LABEL: @TestByValArray dso_preemptable{{$}} +; CHECK-NEXT: args uses: +; CHECK-NEXT: allocas uses: +; CHECK-NEXT: z[800000]: [500000,1300000) +; CHECK-NOT: ]: +entry: + %z = alloca [100000 x i64], align 4 + %z1 = bitcast [100000 x i64]* %z to i8* + %z2 = getelementptr i8, i8* %z1, i64 500000 + %z3 = bitcast i8* %z2 to [100000 x i64]* + call void @ByValArray([100000 x i64]* byval %z3) + ret void +} \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index 235c322d748896..b7180399a837fe 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -19,7 +19,6 @@ #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" -#include #include // Forward declare enum classes related to op availability. Their definitions @@ -277,40 +276,22 @@ class StructType : public Type::TypeBasememberIndex == other.memberIndex) && - (this->decoration == other.decoration) && - (this->decorationValue == other.decorationValue); - } + // Layout information used for members in a struct in SPIR-V + // + // TODO(ravishankarm) : For now this only supports the offset type, so uses + // uint64_t value to represent the offset, with + // std::numeric_limit::max indicating no offset. Change this to + // something that can hold all the information needed for different member + // types + using LayoutInfo = uint64_t; - bool operator<(const MemberDecorationInfo &other) const { - return this->memberIndex < other.memberIndex || - (this->memberIndex == other.memberIndex && - static_cast(this->decoration) < - static_cast(other.decoration)); - } - }; + using MemberDecorationInfo = std::pair; static bool kindof(unsigned kind) { return kind == TypeKind::Struct; } /// Construct a StructType with at least one member. static StructType get(ArrayRef memberTypes, - ArrayRef offsetInfo = {}, + ArrayRef layoutInfo = {}, ArrayRef memberDecorations = {}); /// Construct a struct with no members. @@ -342,9 +323,9 @@ class StructType : public Type::TypeBase - &memberDecorations) const; + void getMemberDecorations( + unsigned i, SmallVectorImpl &memberDecorations) const; void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); @@ -363,9 +343,6 @@ class StructType : public Type::TypeBase storage = llvm::None); }; -llvm::hash_code -hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); - // SPIR-V cooperative matrix type class CooperativeMatrixNVType : public Type::TypeBase memberTypes; - SmallVector offsetInfo; + SmallVector layoutInfo; SmallVector memberDecorations; Size structMemberOffset = 0; @@ -46,8 +46,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType, decorateType(structType.getElementType(i), memberSize, memberAlignment); structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment); memberTypes.push_back(memberType); - offsetInfo.push_back( - static_cast(structMemberOffset)); + layoutInfo.push_back(structMemberOffset); // If the member's size is the max value, it must be the last member and it // must be a runtime array. assert(memberSize != std::numeric_limits().max() || @@ -67,7 +66,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType, size = llvm::alignTo(structMemberOffset, maxMemberAlignment); alignment = maxMemberAlignment; structType.getMemberDecorations(memberDecorations); - return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); + return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations); } Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, @@ -169,7 +168,7 @@ bool VulkanLayoutUtils::isLegalType(Type type) { case spirv::StorageClass::StorageBuffer: case spirv::StorageClass::PushConstant: case spirv::StorageClass::PhysicalStorageBuffer: - return structType.hasOffset() || !structType.getNumElements(); + return structType.hasLayout() || !structType.getNumElements(); default: return true; } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index 43e70a1bdc637e..455064f58ce699 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -535,31 +535,30 @@ static Type parseImageType(SPIRVDialect const &dialect, static ParseResult parseStructMemberDecorations( SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef memberTypes, - SmallVectorImpl &offsetInfo, + SmallVectorImpl &layoutInfo, SmallVectorImpl &memberDecorationInfo) { // Check if the first element is offset. - llvm::SMLoc offsetLoc = parser.getCurrentLocation(); - StructType::OffsetInfo offset = 0; - OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset); - if (offsetParseResult.hasValue()) { - if (failed(*offsetParseResult)) + llvm::SMLoc layoutLoc = parser.getCurrentLocation(); + StructType::LayoutInfo layout = 0; + OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout); + if (layoutParseResult.hasValue()) { + if (failed(*layoutParseResult)) return failure(); - if (offsetInfo.size() != memberTypes.size() - 1) { - return parser.emitError(offsetLoc, - "offset specification must be given for " - "all members"); + if (layoutInfo.size() != memberTypes.size() - 1) { + return parser.emitError( + layoutLoc, "layout specification must be given for all members"); } - offsetInfo.push_back(offset); + layoutInfo.push_back(layout); } // Check for no spirv::Decorations. if (succeeded(parser.parseOptionalRSquare())) return success(); - // If there was an offset, make sure to parse the comma. - if (offsetParseResult.hasValue() && parser.parseComma()) + // If there was a layout, make sure to parse the comma. + if (layoutParseResult.hasValue() && parser.parseComma()) return failure(); // Check for spirv::Decorations. @@ -568,23 +567,9 @@ static ParseResult parseStructMemberDecorations( if (!memberDecoration) return failure(); - // Parse member decoration value if it exists. - if (succeeded(parser.parseOptionalEqual())) { - auto memberDecorationValue = - parseAndVerifyInteger(dialect, parser); - - if (!memberDecorationValue) - return failure(); - - memberDecorationInfo.emplace_back( - static_cast(memberTypes.size() - 1), 1, - memberDecoration.getValue(), memberDecorationValue.getValue()); - } else { - memberDecorationInfo.emplace_back( - static_cast(memberTypes.size() - 1), 0, - memberDecoration.getValue(), 0); - } - + memberDecorationInfo.emplace_back( + static_cast(memberTypes.size() - 1), + memberDecoration.getValue()); } while (succeeded(parser.parseOptionalComma())); return parser.parseRSquare(); @@ -602,7 +587,7 @@ static Type parseStructType(SPIRVDialect const &dialect, return StructType::getEmpty(dialect.getContext()); SmallVector memberTypes; - SmallVector offsetInfo; + SmallVector layoutInfo; SmallVector memberDecorationInfo; do { @@ -612,21 +597,21 @@ static Type parseStructType(SPIRVDialect const &dialect, memberTypes.push_back(memberType); if (succeeded(parser.parseOptionalLSquare())) { - if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, + if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo, memberDecorationInfo)) { return Type(); } } } while (succeeded(parser.parseOptionalComma())); - if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { + if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) { parser.emitError(parser.getNameLoc(), - "offset specification must be given for all members"); + "layout specification must be given for all members"); return Type(); } if (parser.parseGreater()) return Type(); - return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); + return StructType::get(memberTypes, layoutInfo, memberDecorationInfo); } // spirv-type ::= array-type @@ -694,20 +679,17 @@ static void print(StructType type, DialectAsmPrinter &os) { os << "struct<"; auto printMember = [&](unsigned i) { os << type.getElementType(i); - SmallVector decorations; + SmallVector decorations; type.getMemberDecorations(i, decorations); - if (type.hasOffset() || !decorations.empty()) { + if (type.hasLayout() || !decorations.empty()) { os << " ["; - if (type.hasOffset()) { - os << type.getMemberOffset(i); + if (type.hasLayout()) { + os << type.getOffset(i); if (!decorations.empty()) os << ", "; } - auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { - os << stringifyDecoration(decoration.decoration); - if (decoration.hasValue) { - os << "=" << decoration.decorationValue; - } + auto eachFn = [&os](spirv::Decoration decoration) { + os << stringifyDecoration(decoration); }; llvm::interleaveComma(decorations, os, eachFn); os << "]"; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index 963f5393c572f8..0226f51175400e 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -874,17 +874,17 @@ void SPIRVType::getCapabilities( struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage( unsigned numMembers, Type const *memberTypes, - StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, + StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo) : TypeStorage(numMembers), memberTypes(memberTypes), - offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations), + layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations), memberDecorationsInfo(memberDecorationsInfo) {} - using KeyTy = std::tuple, ArrayRef, + using KeyTy = std::tuple, ArrayRef, ArrayRef>; bool operator==(const KeyTy &key) const { return key == - KeyTy(getMemberTypes(), getOffsetInfo(), getMemberDecorationsInfo()); + KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo()); } static StructTypeStorage *construct(TypeStorageAllocator &allocator, @@ -897,13 +897,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { typesList = allocator.copyInto(keyTypes).data(); } - const StructType::OffsetInfo *offsetInfoList = nullptr; + const StructType::LayoutInfo *layoutInfoList = nullptr; if (!std::get<1>(key).empty()) { - ArrayRef keyOffsetInfo = std::get<1>(key); - assert(keyOffsetInfo.size() == keyTypes.size() && - "size of offset information must be same as the size of number of " + ArrayRef keyLayoutInfo = std::get<1>(key); + assert(keyLayoutInfo.size() == keyTypes.size() && + "size of layout information must be same as the size of number of " "elements"); - offsetInfoList = allocator.copyInto(keyOffsetInfo).data(); + layoutInfoList = allocator.copyInto(keyLayoutInfo).data(); } const StructType::MemberDecorationInfo *memberDecorationList = nullptr; @@ -914,7 +914,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } return new (allocator.allocate()) - StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, + StructTypeStorage(keyTypes.size(), typesList, layoutInfoList, numMemberDecorations, memberDecorationList); } @@ -922,9 +922,9 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { return ArrayRef(memberTypes, getSubclassData()); } - ArrayRef getOffsetInfo() const { - if (offsetInfo) { - return ArrayRef(offsetInfo, getSubclassData()); + ArrayRef getLayoutInfo() const { + if (layoutInfo) { + return ArrayRef(layoutInfo, getSubclassData()); } return {}; } @@ -938,14 +938,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { } Type const *memberTypes; - StructType::OffsetInfo const *offsetInfo; + StructType::LayoutInfo const *layoutInfo; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; }; StructType StructType::get(ArrayRef memberTypes, - ArrayRef offsetInfo, + ArrayRef layoutInfo, ArrayRef memberDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); // Sort the decorations. @@ -953,12 +953,12 @@ StructType::get(ArrayRef memberTypes, memberDecorations.begin(), memberDecorations.end()); llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct, - memberTypes, offsetInfo, sortedDecorations); + memberTypes, layoutInfo, sortedDecorations); } StructType StructType::getEmpty(MLIRContext *context) { return Base::get(context, TypeKind::Struct, ArrayRef(), - ArrayRef(), + ArrayRef(), ArrayRef()); } @@ -975,11 +975,11 @@ StructType::ElementTypeRange StructType::getElementTypes() const { return ElementTypeRange(getImpl()->memberTypes, getNumElements()); } -bool StructType::hasOffset() const { return getImpl()->offsetInfo; } +bool StructType::hasLayout() const { return getImpl()->layoutInfo; } -uint64_t StructType::getMemberOffset(unsigned index) const { +uint64_t StructType::getOffset(unsigned index) const { assert(getNumElements() > index && "member index out of range"); - return getImpl()->offsetInfo[index]; + return getImpl()->layoutInfo[index]; } void StructType::getMemberDecorations( @@ -992,16 +992,15 @@ void StructType::getMemberDecorations( } void StructType::getMemberDecorations( - unsigned index, - SmallVectorImpl &decorationsInfo) const { + unsigned index, SmallVectorImpl &decorations) const { assert(getNumElements() > index && "member index out of range"); auto memberDecorations = getImpl()->getMemberDecorationsInfo(); - decorationsInfo.clear(); - for (const auto &memberDecoration : memberDecorations) { - if (memberDecoration.memberIndex == index) { - decorationsInfo.push_back(memberDecoration); + decorations.clear(); + for (auto &memberDecoration : memberDecorations) { + if (memberDecoration.first == index) { + decorations.push_back(memberDecoration.second); } - if (memberDecoration.memberIndex > index) { + if (memberDecoration.first > index) { // Early exit since the decorations are stored sorted. return; } @@ -1021,12 +1020,6 @@ void StructType::getCapabilities( elementType.cast().getCapabilities(capabilities, storage); } -llvm::hash_code spirv::hash_value( - const StructType::MemberDecorationInfo &memberDecorationInfo) { - return llvm::hash_combine(memberDecorationInfo.memberIndex, - memberDecorationInfo.decoration); -} - //===----------------------------------------------------------------------===// // MatrixType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 4bb9e6d26b97d2..ecd79d7153ab82 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1305,7 +1305,7 @@ LogicalResult Deserializer::processStructType(ArrayRef operands) { memberTypes.push_back(memberType); } - SmallVector offsetInfo; + SmallVector layoutInfo; SmallVector memberDecorationsInfo; if (memberDecorationMap.count(operands[0])) { auto &allMemberDecorations = memberDecorationMap[operands[0]]; @@ -1314,27 +1314,27 @@ LogicalResult Deserializer::processStructType(ArrayRef operands) { for (auto &memberDecoration : allMemberDecorations[memberIndex]) { // Check for offset. if (memberDecoration.first == spirv::Decoration::Offset) { - // If offset info is empty, resize to the number of members; - if (offsetInfo.empty()) { - offsetInfo.resize(memberTypes.size()); + // If layoutInfo is empty, resize to the number of members; + if (layoutInfo.empty()) { + layoutInfo.resize(memberTypes.size()); } - offsetInfo[memberIndex] = memberDecoration.second[0]; + layoutInfo[memberIndex] = memberDecoration.second[0]; } else { if (!memberDecoration.second.empty()) { - memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, - memberDecoration.first, - memberDecoration.second[0]); - } else { - memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, - memberDecoration.first, 0); + return emitError(unknownLoc, + "unhandled OpMemberDecoration with decoration ") + << stringifyDecoration(memberDecoration.first) + << " which has additional operands"; } + memberDecorationsInfo.emplace_back(memberIndex, + memberDecoration.first); } } } } } typeMap[operands[0]] = - spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo); // TODO(ravishankarm): Update StructType to have member name as attribute as // well. return success(); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index f8641873fd958c..81f873281c6334 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -227,9 +227,9 @@ class Serializer { } /// Process member decoration - LogicalResult processMemberDecoration( - uint32_t structID, - const spirv::StructType::MemberDecorationInfo &memberDecorationInfo); + LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberIndex, + spirv::Decoration decorationType, + ArrayRef values = {}); //===--------------------------------------------------------------------===// // Types @@ -736,14 +736,14 @@ LogicalResult Serializer::processTypeDecoration( return success(); } -LogicalResult Serializer::processMemberDecoration( - uint32_t structID, - const spirv::StructType::MemberDecorationInfo &memberDecoration) { +LogicalResult +Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex, + spirv::Decoration decorationType, + ArrayRef values) { SmallVector args( - {structID, memberDecoration.memberIndex, - static_cast(memberDecoration.decoration)}); - if (memberDecoration.hasValue) { - args.push_back(memberDecoration.decorationValue); + {structID, memberIndex, static_cast(decorationType)}); + if (!values.empty()) { + args.append(values.begin(), values.end()); } return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); @@ -1070,7 +1070,7 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, } if (auto structType = type.dyn_cast()) { - bool hasOffset = structType.hasOffset(); + bool hasLayout = structType.hasLayout(); for (auto elementIndex : llvm::seq(0, structType.getNumElements())) { uint32_t elementTypeID = 0; @@ -1079,12 +1079,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, return failure(); } operands.push_back(elementTypeID); - if (hasOffset) { + if (hasLayout) { // Decorate each struct member with an offset - spirv::StructType::MemberDecorationInfo offsetDecoration{ - elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, - static_cast(structType.getMemberOffset(elementIndex))}; - if (failed(processMemberDecoration(resultID, offsetDecoration))) { + if (failed(processMemberDecoration( + resultID, elementIndex, spirv::Decoration::Offset, + static_cast(structType.getOffset(elementIndex))))) { return emitError(loc, "cannot decorate ") << elementIndex << "-th member of " << structType << " with its offset"; @@ -1094,11 +1093,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, SmallVector memberDecorations; structType.getMemberDecorations(memberDecorations); for (auto &memberDecoration : memberDecorations) { - if (failed(processMemberDecoration(resultID, memberDecoration))) { + if (failed(processMemberDecoration(resultID, memberDecoration.first, + memberDecoration.second))) { return emitError(loc, "cannot decorate ") - << static_cast(memberDecoration.memberIndex) - << "-th member of " << structType << " with " - << stringifyDecoration(memberDecoration.decoration); + << memberDecoration.first << "-th member of " << structType + << " with " << stringifyDecoration(memberDecoration.second); } } typeEnum = spirv::Opcode::OpTypeStruct; diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir index fff591d2f24e31..3066462fd71b6b 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir @@ -22,9 +22,6 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK: !spv.ptr, StorageBuffer> spv.globalVariable @var6 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr> [0, ColMajor, MatrixStride=16]>, StorageBuffer> - spv.globalVariable @var7 : !spv.ptr> [0, ColMajor, MatrixStride=16]>, StorageBuffer> - // CHECK: !spv.ptr, StorageBuffer> spv.globalVariable @empty : !spv.ptr, StorageBuffer> diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir index d5eb073c9aa550..1d1a1868ea3c5f 100644 --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -275,23 +275,17 @@ func @struct_type_with_decoration7(!spv.struct>) func @struct_type_with_decoration8(!spv.struct>) -// CHECK: func @struct_type_with_matrix_1(!spv.struct> [0, ColMajor, MatrixStride=16]>) -func @struct_type_with_matrix_1(!spv.struct> [0, ColMajor, MatrixStride=16]>) - -// CHECK: func @struct_type_with_matrix_2(!spv.struct> [0, RowMajor, MatrixStride=16]>) -func @struct_type_with_matrix_2(!spv.struct> [0, RowMajor, MatrixStride=16]>) - // CHECK: func @struct_empty(!spv.struct<>) func @struct_empty(!spv.struct<>) // ----- -// expected-error @+1 {{offset specification must be given for all members}} +// expected-error @+1 {{layout specification must be given for all members}} func @struct_type_missing_offset1((!spv.struct) -> () // ----- -// expected-error @+1 {{offset specification must be given for all members}} +// expected-error @+1 {{layout specification must be given for all members}} func @struct_type_missing_offset2(!spv.struct) -> () // ----- @@ -336,16 +330,6 @@ func @struct_type_missing_comma(!spv.struct> [0, RowMajor MatrixStride=16]>) - -// ----- - -// expected-error @+1 {{expected integer value}} -func @struct_missing_member_decorator_value(!spv.struct> [0, RowMajor, MatrixStride=]>) - -// ----- - //===----------------------------------------------------------------------===// // CooperativeMatrix //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp index 340bfd939e1b3d..06c417ca23f74a 100644 --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -58,8 +58,8 @@ class SerializationTest : public ::testing::Test { Type getFloatStructType() { OpBuilder opBuilder(module.body()); llvm::SmallVector elementTypes{opBuilder.getF32Type()}; - llvm::SmallVector offsetInfo{0}; - auto structType = spirv::StructType::get(elementTypes, offsetInfo); + llvm::SmallVector layoutInfo{0}; + auto structType = spirv::StructType::get(elementTypes, layoutInfo); return structType; }