MT model compilation minor changes
This commit contains the following changes:
 - Fix optional knowledge propagation: the initial knowledge should always be NotNone for the operations we have implemented.
 - Add a folder for `prim.dtype`; see the sketch right after this list.
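In practice, the new folder lets `prim.dtype` constant-fold to the integer dtype code that PyTorch itself uses (`c10::ScalarType`) whenever the operand tensor's element type is statically known. Below is a minimal, self-contained check of that correspondence, assuming ATen's `c10/core/ScalarType.h` is available; the codes are the same ones the fold in TorchOps.cpp further down returns:

    #include <cstdint>
    #include <c10/core/ScalarType.h>

    // PyTorch encodes dtypes as small integers (c10::ScalarType). The new
    // PrimDtypeOp fold returns exactly these codes.
    static_assert(static_cast<int64_t>(c10::ScalarType::Float) == 6,
                  "MLIR f32 corresponds to torch.float32 (code 6)");
    static_assert(static_cast<int64_t>(c10::ScalarType::Long) == 4,
                  "MLIR si64 corresponds to torch.int64 (code 4)");
    static_assert(static_cast<int64_t>(c10::ScalarType::Bool) == 11,
                  "MLIR i1 corresponds to torch.bool (code 11)");

    int main() { return 0; }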
cathyzhyi committed Sep 9, 2021
1 parent 5f3eb63 commit 200e8fa
Showing 6 changed files with 150 additions and 52 deletions.
@@ -398,7 +398,7 @@ def emit(key, **kwargs):
emit("prim::layout : (Tensor) -> (int)")
emit("prim::TupleIndex : (Any, int) -> (Any)")
emit("prim::device : (Tensor) -> (Device)")
emit("prim::dtype : (Tensor) -> (int)")
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
emit("prim::min.self_int : (int[]) -> (int)")
include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td (34 additions, 16 deletions)
@@ -791,6 +791,40 @@ def Torch_AtenTriu_Op : Torch_Op<"aten.triu_", [
let assemblyFormat = "$self `,` $diagonal attr-dict `:` type($self) `,` type($diagonal) `->` type($result)";
}

+def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalTensorListType:$indices,
+    AnyTorchTensorType:$values,
+    Torch_BoolType:$accumulate
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
+}
+
+def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalTensorListType:$indices,
+    AnyTorchTensorType:$values,
+    Torch_BoolType:$accumulate
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
+}
+
def Torch_AtenLinearOp : Torch_Op<"aten.linear", [
AllowsTypeRefinement,
HasValueSemantics
@@ -1404,22 +1438,6 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
let assemblyFormat = "$self `,` $indices attr-dict `:` type($self) `,` type($indices) `->` type($result)";
}

-def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchOptionalTensorListType:$indices,
-    AnyTorchTensorType:$values,
-    Torch_BoolType:$accumulate
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)";
-}
-
def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
AllowsTypeRefinement,
HasValueSemantics
include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td (1 addition, 0 deletions)
@@ -70,6 +70,7 @@ def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
Torch_IntType:$result
);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
+  let hasFolder = 1;
}

def Torch_PrimTupleUnpackOp : Torch_Op<"prim.TupleUnpack", [
lib/Dialect/Torch/IR/TorchOps.cpp (30 additions, 0 deletions)
@@ -18,6 +18,19 @@ using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;

+static int64_t getDtypeIntegerFromMlirType(Type dtype) {
+  if (dtype.isa<Float32Type>())
+    return 6;
+
+  if (auto integerType = dtype.dyn_cast<IntegerType>()) {
+    if (integerType.isSignedInteger(64))
+      return 4;
+    if (integerType.isSignlessInteger(1))
+      return 11;
+  }
+  return -1;
+}
+
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
@@ -129,6 +142,10 @@ bool isValidSubtype(Type subtype, Type type) {
type ==
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
+
+  if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
+      type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
+    return true;
return false;
}

@@ -972,5 +989,18 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}

+//===----------------------------------------------------------------------===//
+// PrimDtypeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
+  BaseTensorType tensorType = a().getType().cast<BaseTensorType>();
+  if (tensorType.hasDtype()) {
+    int64_t dtypeInt = getDtypeIntegerFromMlirType(tensorType.getDtype());
+    if (dtypeInt != -1)
+      return getI64IntegerAttr(getContext(), dtypeInt);
+  }
+  return nullptr;
+}
#define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"