llvm: update tag to 061e0189 (#1180)
Summary of changes:
 - Switch to C++17 (similar to https://reviews.llvm.org/D131348)
 - Update MHLO to build with LLVM commit hash 061e0189
 - Replace deprecated `hasValue()` and `getValue()` with `has_value()`
   and `value()` respectively (https://reviews.llvm.org/D131349); see the
   first sketch after this list
 - Use `TypedAttr` (https://reviews.llvm.org/D130092); see the second
   sketch after this list
 - Use updated assembly format of `mhlo.compare` op (commit
   d03ef01e70fbf9afd0fa1976fbb7ed31838929b3 in MHLO repo)
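
For context, a minimal sketch (not taken from this commit) of the
`llvm::Optional` accessor rename referenced above; `firstSizeOrZero` is a
hypothetical helper, not torch-mlir API:

```cpp
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"

// The deprecated hasValue()/getValue() spellings give way to the
// std::optional-compatible has_value()/value() accessors.
int64_t firstSizeOrZero(llvm::Optional<llvm::ArrayRef<int64_t>> sizes) {
  if (!sizes.has_value()) // was: sizes.hasValue()
    return 0;
  // Assumes a non-empty size list when one is present.
  return sizes.value()[0]; // was: sizes.getValue()[0]
}
```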
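And a hedged sketch of the `TypedAttr` change: `getType()` is no longer
available on a plain `mlir::Attribute` and instead lives on the `TypedAttr`
interface, so call sites cast first (mirroring the `lib/CAPI/TorchTypes.cpp`
hunk below); `elementTypeOfAttr` is a hypothetical helper:

```cpp
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"

// After https://reviews.llvm.org/D130092, an attribute must be cast to the
// TypedAttr interface before its type can be queried.
mlir::Type elementTypeOfAttr(mlir::Attribute attr) {
  auto tensorType =
      attr.cast<mlir::TypedAttr>().getType().cast<mlir::RankedTensorType>();
  return tensorType.getElementType();
}
```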
ashay committed Aug 9, 2022
1 parent 3e97a33 commit bb47c16
Showing 21 changed files with 106 additions and 108 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -23,7 +23,7 @@ endif()

project(torch-mlir LANGUAGES CXX C)
set(CMAKE_C_STANDARD 11)
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)

macro(torch_mlir_add_llvm_external_project name identifier location)
message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}")
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 2357 files
2 changes: 1 addition & 1 deletion externals/mlir-hlo
Submodule mlir-hlo updated 63 files
+3 −0 .clang-tidy
+48 −0 BUILD
+2 −2 WORKSPACE
+1 −1 build_tools/llvm_version.txt
+5 −3 include/mlir-hlo/Dialect/gml_st/IR/gml_st_extension_ops.td
+5 −0 include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt
+24 −0 include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h
+195 −0 include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td
+14 −3 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+3 −3 include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h
+11 −13 lib/CAPI/Attributes.cc
+221 −37 lib/Dialect/gml_st/IR/gml_st_ops.cc
+11 −0 lib/Dialect/gml_st/transforms/CMakeLists.txt
+2 −0 lib/Dialect/gml_st/transforms/fusion.cc
+239 −78 lib/Dialect/gml_st/transforms/fusion_interface_impl.cc
+78 −5 lib/Dialect/gml_st/transforms/legalize_mhlo_to_gml.cc
+2 −2 lib/Dialect/gml_st/transforms/tiling.cc
+24 −0 lib/Dialect/gml_st/transforms/tiling_interface.cc
+23 −10 lib/Dialect/gml_st/transforms/transforms.cc
+3 −3 lib/Dialect/lhlo/IR/lhlo_ops.cc
+2 −2 lib/Dialect/lhlo/transforms/lhlo_legalize_to_parallel_loops.cc
+8 −5 lib/Dialect/mhlo/IR/chlo_ops.cc
+316 −191 lib/Dialect/mhlo/IR/hlo_ops.cc
+1 −1 lib/Dialect/mhlo/IR/hlo_ops_common.cc
+1 −1 lib/Dialect/mhlo/transforms/broadcast_propagation.cc
+2 −1 lib/Dialect/mhlo/transforms/constraint_fusion_pass.cc
+2 −2 lib/Dialect/mhlo/transforms/group_reduction_dimensions.cc
+1 −1 lib/Dialect/mhlo/transforms/legalize_control_flow.cc
+2 −2 lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc
+2 −2 lib/Dialect/mhlo/transforms/legalize_sort.cc
+407 −50 lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+4 −8 lib/Dialect/mhlo/transforms/legalize_to_standard.cc
+3 −2 lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc
+4 −2 lib/Dialect/mhlo/transforms/lower_general_dot.cc
+5 −3 lib/Dialect/mhlo/transforms/optimize_mhlo.cc
+1 −1 lib/Dialect/mhlo/transforms/type_conversion.cc
+3 −3 lib/Transforms/buffer_packing.cc
+1 −1 lib/Transforms/bufferize.cc
+1 −1 lib/Transforms/copy_removal.cc
+11 −2 lib/Transforms/gpu_fusion_rewrite.cc
+4 −0 lib/Transforms/hlo_to_gpu_pipeline.cc
+1 −1 lib/Transforms/propagate_static_shapes_to_kernel.cc
+3 −2 lib/Transforms/symbolic_shape_optimization.cc
+2 −2 lib/Transforms/tile_loops_pass.cc
+2 −2 tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir
+86 −86 tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
+144 −1 tests/Dialect/gml_st/fusion.mlir
+38 −0 tests/Dialect/gml_st/legalize_mhlo_to_gml.mlir
+63 −18 tests/Dialect/gml_st/ops.mlir
+212 −0 tests/Dialect/gml_st/tiling_and_fusion.mlir
+3 −4 tests/Dialect/mhlo/canonicalize/canonicalize.mlir
+66 −35 tests/Dialect/mhlo/canonicalize/convolution.mlir
+65 −1 tests/Dialect/mhlo/canonicalize/reverse.mlir
+148 −0 tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+6 −6 tests/Dialect/mhlo/legalize-control-flow.mlir
+4 −4 tests/Dialect/mhlo/lower-complex.mlir
+20 −0 tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir
+91 −28 tests/Dialect/mhlo/ops.mlir
+1 −1 tests/Dialect/mhlo/sink-constants-to-control-flow.mlir
+55 −1 tests/gpu_fusion_rewrite.mlir
+1 −2 tests/propagate_static_shapes.mlir
+4 −4 tests/rank-specialization.mlir
+1 −1 tosa/CMakeLists.txt
4 changes: 2 additions & 2 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
@@ -54,12 +54,12 @@ class BaseTensorType : public Type {
Type getOptionalDtype() const;

/// Return true if this type has a list of sizes.
-bool hasSizes() const { return getOptionalSizes().hasValue(); }
+bool hasSizes() const { return getOptionalSizes().has_value(); }

/// Get the list of sizes. Requires `hasSizes()`.
ArrayRef<int64_t> getSizes() const {
assert(hasSizes() && "must have sizes");
-return getOptionalSizes().getValue();
+return getOptionalSizes().value();
}

/// Return true if all sizes of this tensor are known.
4 changes: 2 additions & 2 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -49,8 +49,8 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
AttrOrTypeParameter<
"::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
let allocator = [{
-if ($_self.hasValue()) {
-$_dst.getValue() = $_allocator.copyInto($_self.getValue());
+if ($_self.has_value()) {
+$_dst.value() = $_allocator.copyInto($_self.value());
}
}];
}
3 changes: 2 additions & 1 deletion lib/CAPI/TorchTypes.cpp
@@ -213,7 +213,8 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
}

MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
-auto attrTensorType = unwrap(attr).getType().cast<RankedTensorType>();
+auto attrTensorType =
+    unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
attrTensorType.getShape(),
attrTensorType.getElementType()));
12 changes: 6 additions & 6 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -342,7 +342,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
continue;
}

-if (inferredDimension.hasValue()) {
+if (inferredDimension.has_value()) {
return rewriter.notifyMatchFailure(
op, "at most one element in size list is allowed to be -1");
}
@@ -363,7 +363,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
// then we don't need to analyze the static information of the input
// shape since the reassociation of dimensions only requires rank
// information.
-if (inferredDimension.hasValue() && outputShape.size() > 1) {
+if (inferredDimension.has_value() && outputShape.size() > 1) {
if (llvm::count(outputShape, kUnknownSize) != 1 ||
llvm::count(inputShape, kUnknownSize) != 0) {
return rewriter.notifyMatchFailure(
@@ -585,14 +585,14 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
collapsedInput = rewriter
.create<tensor::ExpandShapeOp>(
loc, adjustedResultType,
-expandedInput.hasValue() ? expandedInput.value()
-: castedInput,
+expandedInput.has_value() ? expandedInput.value()
+: castedInput,
outputAssociations)
.result();
}

-Value result = collapsedInput.hasValue() ? collapsedInput.value()
-: expandedInput.value();
+Value result = collapsedInput.has_value() ? collapsedInput.value()
+: expandedInput.value();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
10 changes: 5 additions & 5 deletions lib/Conversion/TorchToMhlo/Basic.cpp
@@ -119,7 +119,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {

SmallVector<int32_t> values(size, fillVal);
auto constOp =
-mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).getValue();
+mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();

rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
return success();
@@ -884,7 +884,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
op->getLoc(), mhloBatchNormOutTy, input,
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())})
-.getValue());
+.value());

// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec(
@@ -920,19 +920,19 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
{static_cast<int64_t>(outputTy.getShape().size())})
-.getValue());
+.value());
auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
mhlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
-.getValue());
+.value());
auto var = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
mhlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
-.getValue());
+.value());

// Apply affine transform: output x weight + bias [element-wise]
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
7 changes: 3 additions & 4 deletions lib/Conversion/TorchToMhlo/Pooling.cpp
@@ -314,8 +314,7 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
initIndexTensor, inputShapeTensor)
.getResult();

-Value initIdx =
-mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
+Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();

auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
@@ -491,7 +490,7 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
if (countIncludePad) {
Value divisor = mhlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
-.getValue();
+.value();
divisor = mhlo::promoteType(rewriter, divisor, outTy);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
@@ -501,7 +500,7 @@

// Use another mhlo.ReduceWindowOp to get the divisor
Value windowSizeConst =
-mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue();
+mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
6 changes: 3 additions & 3 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
@@ -87,7 +87,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
if (!initValue) return llvm::None;

Value initIndex =
-mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
+mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();

DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
RankedTensorType::get({}, rewriter.getI64Type()), dim);
Expand Down Expand Up @@ -224,7 +224,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
}
auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults =
-getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
+getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();

if (keepDim) {
auto outShapeVec = inputShapeVec;
Expand Down Expand Up @@ -301,7 +301,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
}
auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults =
-getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
+getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();

if (keepDim) {
auto outShapeVec = inputShapeVec;
5 changes: 2 additions & 3 deletions lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -178,16 +178,15 @@ class ConvertTorchTensorLiteralOp
}));
return success();
}
-if (auto elements = op.valueAttr().dyn_cast<OpaqueElementsAttr>()) {
+if (auto elements = op.valueAttr().dyn_cast<SparseElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
Type builtinTensorElemTy =
IntegerType::get(context, intType.getIntOrFloatBitWidth());
auto shapedType =
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-op, OpaqueElementsAttr::get(elements.getDialect(), shapedType,
-elements.getValue()));
+op, DenseElementsAttr::get(shapedType, elements.getValues()));
return success();
}
}
