Skip to content

Commit

Permalink
[Torch] Fix AtenSliceTensorOp::fold (llvm#3345)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu Yang authored and Branko Trifkovic committed May 24, 2024
1 parent 1a3299f commit e507f30
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3570,17 +3570,17 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
auto outType = dyn_cast<ValueTensorType>(getResult().getType());

if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!inType.hasDtype() || !outType.hasDtype() ||
inType.getDtype() != outType.getDtype())
return nullptr;

if (start && end && step && step.getValue().getSExtValue() == 1 &&
start.getValue().getSExtValue() == 0 &&
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
inType == outType)
return getOperand(0);

if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!inType.hasDtype() || !outType.hasDtype() ||
inType.getDtype() != outType.getDtype())
return nullptr;

if (inType.getSizes().size() != outType.getSizes().size() ||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
return nullptr;
Expand Down

0 comments on commit e507f30

Please sign in to comment.