Skip to content

Commit

Permalink
Fix lowering of reduce ops
Browse files Browse the repository at this point in the history
We were not filling the `outs` with the neutral element of the
reduction, which resulted in reading uninitialized values (we were
getting lucky that sometimes the uninitialized buffers were all zeros).

Also,
- Slight tweak to error messages in the e2e framework.
  • Loading branch information
silvasean committed Sep 8, 2021
1 parent 6724de7 commit 5f3eb63
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ def __init__(self, tensor):
self.min = torch.min(tensor)
self.max = torch.max(tensor)
self.mean = torch.mean(tensor)
self.shape = list(tensor.shape)

def __str__(self):
return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
return f'Tensor with shape={self.shape} min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}'


class ErrorContext:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_recursive(self):

# CHECK-NEXT: @ trace item #8 - call to "test_tensor_value_mismatch"
# CHECK-NEXT: @ output of call to "test_tensor_value_mismatch"
# CHECK-NEXT: ERROR: value (Tensor with min=+1.0, max=+3.0, mean=+2.0000) is not close to golden value (Tensor with min=+1.5, max=+3.5, mean=+2.5000)
# CHECK-NEXT: ERROR: value (Tensor with shape=[3] min=+1.0, max=+3.0, mean=+2.0) is not close to golden value (Tensor with shape=[3] min=+1.5, max=+3.5, mean=+2.5)
@torch.jit.export
def test_tensor_value_mismatch(self):
if torch.jit.is_scripting():
Expand Down
21 changes: 19 additions & 2 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr;
}

// Materialize the neutral (identity) element used to seed the `outs`
// accumulator of a linalg reduction lowered from the given torch op.
// Returns nullptr (after emitting an error) for reductions we have not
// implemented yet.
static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
                                                   Operation *op,
                                                   Type elementType) {
  bool isSumLikeOp = isa<AtenSumOp, AtenSumDimIntListOp>(op);
  bool isFloatElement = elementType.isa<mlir::FloatType>();
  if (isSumLikeOp && isFloatElement) {
    // Sum reductions accumulate starting from the additive identity 0.0.
    auto zeroAttr = b.getFloatAttr(elementType, 0.0);
    return b.create<mlir::ConstantOp>(loc, zeroAttr);
  }

  op->emitError("unimplemented lowering in "
                "createLinalgNeutralElementForReduceOp");
  return nullptr;
}

static Value createLinalgPayloadCalculationForReduceOp(
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
ArrayRef<Value> operands, Type elementType) {
Expand Down Expand Up @@ -981,11 +993,16 @@ struct ConvertReductionOp : ConversionPattern {
auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultShape, resultType.getElementType());
Value initValue = createLinalgNeutralElementForReduceOp(
rewriter, loc, op, resultType.getElementType());
Value accumulator =
rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
.getResult(0);
bool hadErrorCreatingPayload = false;
auto generic = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/initTensor.getType(),
loc, /*resultTensorTypes=*/accumulator.getType(),
/*inputs=*/tensorOperand,
/*outputs=*/initTensor,
/*outputs=*/accumulator,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Expand Down

0 comments on commit 5f3eb63

Please sign in to comment.