[MHLO] support non-constant torch scalar in BasicOps (#1134)
See RFC #999

Co-authored-by: Bairen Yi yibairen.byron@bytedance.com
Co-authored-by: Jiawei Wu xremold@gmail.com
Co-authored-by: Tianyou Guo tianyou.gty@alibaba-inc.com
Co-authored-by: Xu Yan yancey.yx@alibaba-inc.com
Co-authored-by: Ziheng Jiang ziheng.jiang@bytedance.com
Tanyo Kwok committed Aug 3, 2022
1 parent 82af44d commit 0b23af2
Showing 5 changed files with 362 additions and 350 deletions.
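For orientation, the sketch below is editorial and not part of the commit; the function name @num_to_tensor_nonconst and the non-constant !torch.int argument are hypothetical. It illustrates the case this change enables: previously, torchScalarToMhloTensor required the Torch scalar to match a compile-time constant (m_TorchConstantInt / m_TorchConstantFloat), so an op fed by a runtime scalar could not be lowered. The new mhlo::scalarToMhloTensor materializes the scalar via tensor.from_elements, mhlo.convert, and mhlo.reshape; the commented IR is an approximation of what the TorchToMhlo conversion would now produce, mirroring the updated FileCheck expectations further down.

// Hypothetical input: the scalar comes from a function argument, not a torch.constant.int.
func.func @num_to_tensor_nonconst(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
  return %0 : !torch.vtensor<[],si64>
}
// Approximate IR after the TorchToMhlo conversion:
//   %0 = torch_c.to_i64 %arg0
//   %1 = tensor.from_elements %0 : tensor<1xi64>
//   %2 = mhlo.convert %1 : tensor<1xi64>
//   %3 = "mhlo.reshape"(%2) : (tensor<1xi64>) -> tensor<i64>
//   %4 = torch_c.from_builtin_tensor %3 : tensor<i64> -> !torch.vtensor<[],si64>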
42 changes: 13 additions & 29 deletions lib/Conversion/TorchToMhlo/BasicOp.cpp
@@ -159,23 +159,15 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
     }
 
     if (!rhsType) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               outElemTy, {})))
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
     }
 
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
 
     if (!skipMultiplyAlpha(op.alpha())) {
-      Value alpha;
-      if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(),
-                                              op.alpha(), alpha, outElemTy, {},
-                                              /*checkForUnity=*/false))) {
-        return op.emitError("currently only scalar constants are supported for "
-                            "alpha in conversion to MHLO operation");
-      }
+      Value alpha =
+          mhlo::scalarToMhloTensor(rewriter, op, adaptor.alpha(), outElemTy);
       DenseIntElementsAttr bcastDimensions;
       rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
                                                   bcastDimensions);
@@ -216,13 +208,13 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
       return op.emitError(
           "only floating-point or integer datatype legalization supported");
     }
-    if (!rhsType) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               outElemTy, {})))
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
-    }
 
+    Value lhsTensor = lhs;
+    if (std::is_same<AtenOpT, AtenSquareOp>()) {
+      rhs = lhs;
+    } else if (!rhsType) {
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
+    }
     DenseIntElementsAttr bcastDimensions;
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
@@ -263,11 +255,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     }
 
     if (!rhsTy) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               lhsElemTy, {}))) {
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
-      }
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), lhsElemTy);
     }
 
     // TODO: what is the PyTorch default type promotion?
@@ -569,12 +557,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
                         .cast<RankedTensorType>();
   auto outputShape = outputType.getShape();
   auto outputElemType = outputType.getElementType();
-  Value mhloTensor;
-  if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor,
-                                           outputElemType, outputShape,
-                                           false))) {
-    return op->emitError("failed lowering PrimNumToTensorScalarOp to MHLO");
-  }
+  Value mhloTensor =
+      mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
   rewriter.replaceOp(op, mhloTensor);
   return success();
 }
@@ -1020,4 +1004,4 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenBatchNormOp);
   INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
 #undef INSERT_ATENOP_PATTERN
-}
+}
98 changes: 10 additions & 88 deletions lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
Expand Up @@ -174,93 +174,15 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
return const_op.getResult();
}

// TODO: Support for variable scalar.
LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value torchScalarValue,
Value &mhloTensor, Type dtype,
llvm::ArrayRef<int64_t> dshape,
bool doBroadcast) {
// Retrieve a const float or int value but create the out Tensor with dtype.
double doubleValue;
auto isFloat =
matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));

int64_t intValue;
auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));

if (!isFloat && !isInt)
return op->emitError("Unable to extract the scalar constant");

if (dtype.isa<mlir::FloatType>()) {
if (doBroadcast) {
mhloTensor = getSplatConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
} else {
mhloTensor = mhlo::getConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
.getValue();
}
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
auto w = intType.getWidth();
if (w != 32 && w != 64)
return op->emitError("Unsupported integer type") << intType;

if (w == 32) {
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
return op->emitError("Supplied value of scalar constant exceeds limits "
"of destination type");
}
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
: static_cast<int32_t>(intValue);
if (doBroadcast) {
mhloTensor =
getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
} else {
mhloTensor =
mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
}
} else if (w == 64) {
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
return op->emitError("Supplied value of scalar constant exceeds limits "
"of destination type");
}
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
if (doBroadcast) {
mhloTensor =
getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
} else {
mhloTensor =
mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
}
}
} else
return op->emitError("Usupported element type");

return success();
}

LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value alphaScalar,
Value &alphaTensor, Type dtype,
llvm::ArrayRef<int64_t> dshape,
bool checkForUnity) {
if (succeeded(torchScalarToMhloTensor(rewriter, op, alphaScalar, alphaTensor,
dtype, dshape)))
return success();

// `alpha` has not been specified.
int64_t alphaValue;
if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
return op->emitError("Currently only scalar constants are supported for "
"alpha in MHLO operation");
// When no alpha has been specified, this must be 1.
if (checkForUnity && alphaValue != 1)
return op->emitError("Unsupported integer value for alpha");

alphaTensor =
mlir::mhlo::getMhloConstTensorSingleF32(rewriter, op, alphaValue);

return success();
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Value scalarValue, Type dtype) {
auto tensor = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ArrayRef<Value>{scalarValue});
auto dtype_tensor =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
return rewriter.create<mhlo::ReshapeOp>(
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
dtype_tensor);
}

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
@@ -439,4 +361,4 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
       .getResult();
 }
 } // namespace mhlo
-} // namespace mlir
+} // namespace mlir
13 changes: 2 additions & 11 deletions lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
@@ -47,17 +47,8 @@ template <typename T>
 Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
                           T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
 
-LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
-                                      Operation *op, Value torchScalarValue,
-                                      Value &mhloTensor, Type dtype,
-                                      llvm::ArrayRef<int64_t> dshape,
-                                      bool doBroadcast = true);
-
-LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
-                                     Operation *op, Value alphaScalar,
-                                     Value &alphaTensor, Type dtype,
-                                     llvm::ArrayRef<int64_t> dshape,
-                                     bool checkForUnity);
+Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
+                         Value scalarValue, Type dtype);
 
 Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
 
16 changes: 10 additions & 6 deletions test/Conversion/TorchToMhlo/basic.mlir
@@ -41,11 +41,15 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
 
 // -----
 
-// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> {
-// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
-// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64>
-// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
+// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic(
+// CHECK-SAME: ) -> !torch.vtensor<[],si64> {
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
+// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
+// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = "mhlo.reshape"(%[[T2]]) : (tensor<1xi64>) -> tensor<i64>
+// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
+// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
 func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
   %int1 = torch.constant.int 1
   %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>
@@ -251,4 +255,4 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
   %2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
   %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list<int>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32>
   return %result0 : !torch.vtensor<[3,7,4,5],f32>
-}
+}