From 5f3eb637c40f24f86a85d90e6fd031c4c788cb39 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Wed, 8 Sep 2021 21:58:15 +0000 Subject: [PATCH] Fix lowering of reduce ops We were not filling the `outs` with the neutral element of the reduction, which resulted in reading uninitialized values (we were getting lucky that sometimes the uninitialized buffers were all zero's). Also, - Slight tweak to error messages in the e2e framework. --- .../e2e_test/reporting.py | 3 ++- .../torchscript_e2e_test/error_reports.py | 2 +- .../TorchToLinalg/TorchToLinalg.cpp | 21 +++++++++++++++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/frontends/pytorch/python/torch_mlir_torchscript/e2e_test/reporting.py b/frontends/pytorch/python/torch_mlir_torchscript/e2e_test/reporting.py index b0b523b0d79..9c0fa936916 100644 --- a/frontends/pytorch/python/torch_mlir_torchscript/e2e_test/reporting.py +++ b/frontends/pytorch/python/torch_mlir_torchscript/e2e_test/reporting.py @@ -22,9 +22,10 @@ def __init__(self, tensor): self.min = torch.min(tensor) self.max = torch.max(tensor) self.mean = torch.mean(tensor) + self.shape = list(tensor.shape) def __str__(self): - return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}' + return f'Tensor with shape={self.shape} min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}' class ErrorContext: diff --git a/frontends/pytorch/test/torchscript_e2e_test/error_reports.py b/frontends/pytorch/test/torchscript_e2e_test/error_reports.py index a87ac800289..00e4f5295e2 100644 --- a/frontends/pytorch/test/torchscript_e2e_test/error_reports.py +++ b/frontends/pytorch/test/torchscript_e2e_test/error_reports.py @@ -116,7 +116,7 @@ def test_recursive(self): # CHECK-NEXT: @ trace item #8 - call to "test_tensor_value_mismatch" # CHECK-NEXT: @ output of call to "test_tensor_value_mismatch" - # CHECK-NEXT: ERROR: value (Tensor with min=+1.0, max=+3.0, mean=+2.0000) is not close to golden value (Tensor with min=+1.5, max=+3.5, mean=+2.5000) + # CHECK-NEXT: ERROR: value (Tensor with shape=[3] min=+1.0, max=+3.0, mean=+2.0) is not close to golden value (Tensor with shape=[3] min=+1.5, max=+3.5, mean=+2.5) @torch.jit.export def test_tensor_value_mismatch(self): if torch.jit.is_scripting(): diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index cfde7595b5e..c5c5eb71eaf 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -755,6 +755,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } +static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc, + Operation *op, + Type elementType) { + if (isa(op) && + elementType.isa()) + return b.create(loc, b.getFloatAttr(elementType, 0.0)); + + op->emitError("unimplemented lowering in " + "createLinalgNeutralElementForReduceOp"); + return nullptr; +} + static Value createLinalgPayloadCalculationForReduceOp( OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op, ArrayRef operands, Type elementType) { @@ -981,11 +993,16 @@ struct ConvertReductionOp : ConversionPattern { auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs}); Value initTensor = rewriter.create( loc, resultShape, resultType.getElementType()); + Value initValue = createLinalgNeutralElementForReduceOp( + rewriter, loc, op, resultType.getElementType()); + Value accumulator = + rewriter.create(loc, initValue, initTensor) + .getResult(0); bool hadErrorCreatingPayload = false; auto generic = rewriter.create( - loc, /*resultTensorTypes=*/initTensor.getType(), + loc, /*resultTensorTypes=*/accumulator.getType(), /*inputs=*/tensorOperand, - /*outputs=*/initTensor, + /*outputs=*/accumulator, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {