Commit f378a38
add shape inference and op filter list (#871)
zzpmiracle authored Dec 18, 2022
1 parent d09e4a9

Showing 6 changed files with 93 additions and 1 deletion.
@@ -46,6 +46,7 @@ std::shared_ptr<SchemaSet> nn_ops_first_input_preserving() {
#if PYTORCH_MAJOR_VERSION == 1 && PYTORCH_MINOR_VERSION > 6
"aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
#endif
"aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
"aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor",
"aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
"aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
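
For context, a schema set named nn_ops_first_input_preserving() is consumed by the shape propagator as a fast path: a node matching one of the schemas listed above simply inherits the type of its first tensor input. The helper below is a hypothetical sketch of that step (the SchemaSet lookup itself is elided, and propagateFirstInputType is an illustrative name, not a Blade symbol):

#include <torch/csrc/jit/ir/ir.h>

using torch::jit::Node;
using torch::jit::TensorType;

// Hypothetical sketch: a node that matched nn_ops_first_input_preserving()
// copies the TensorType of its first input onto its output.
bool propagateFirstInputType(Node* node) {
  auto self = node->input(0)->type()->cast<TensorType>();
  if (!self) {
    return false;  // first input is not a tensor; nothing to propagate
  }
  node->output()->setType(self);  // output preserves the first input's type
  return true;
}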

pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp (32 additions, 0 deletions)
@@ -870,6 +870,7 @@ class ShapePropagator : public PropertyPropBase {
"aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
"aten::glu(Tensor self, int dim) -> Tensor",
"aten::inverse(Tensor self) -> Tensor",
"aten::group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, bool cudnn_enabled) -> Tensor",
"aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
"aten::leaky_relu_(Tensor self, Scalar negative_slope) -> Tensor",
"aten::lgamma(Tensor self) -> Tensor",
@@ -1278,6 +1279,37 @@ class ShapePropagator : public PropertyPropBase {
return {};
}};

  // Requirements:
  //   device         : Device
  //   tensor inputs  : 1
  //   tensor outputs : 1
  // Additionally:
  //   - First input should be the only tensor input
  static const register_formula_for aten_to_device{
      {"aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"},
      [](Node* node) -> type_vec_t {
        at::optional<IValue> maybe_device_option = node->get(attr::device);
        if (auto type = node->input(0)->type()->cast<TensorType>()) {
          auto ret = type;
          if (maybe_device_option && !maybe_device_option->isNone()) {
            auto device = maybe_device_option->toDevice();
#if PYTORCH_VERSION_GE(1, 11)
            return {ret->withDevice(device)};
#else
            return {TensorType::create(
                ret->scalarType(),
                device,
                ret->sizes(),
                ret->strides(),
                /*requires_grad=*/c10::nullopt)};
#endif
          } else {
            return {ret};
          }
        }
        return {};
      }};

  // Requirements:
  //   dims           : 0
  //   scalar type    : preserved
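
The aten_to_device formula above rewrites only the device component of the incoming TensorType: it reads the constant Device argument and, on PyTorch 1.11 or later, rebuilds the type with withDevice, while older releases reassemble the type field by field via TensorType::create. Notably, the ScalarType argument is left untouched, which is why the to.device tests further below still expect Float even though dtype 5 (Half) is passed. A minimal standalone sketch of the newer path (moveType is an illustrative name, not Blade code):

#include <ATen/core/jit_type.h>
#include <c10/core/Device.h>

// Sketch, assuming PyTorch >= 1.11: only the device field of the type is
// replaced; sizes, strides, and scalar type carry over unchanged.
c10::TensorTypePtr moveType(const c10::TensorTypePtr& t, c10::Device d) {
  return t->withDevice(d);
}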

@@ -46,19 +46,21 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::addmm",
"aten::arange",
"aten::unbind",
"aten::baddmm",
"aten::batch_norm",
"aten::bitwise_not",
"aten::bmm",
"aten::cat",
"aten::chunk",
"aten::contiguous",
"aten::_convolution",
"aten::convolution",
"aten::conv1d",
"aten::conv2d",
"aten::cos",
"aten::div",
"aten::einsum",
"aten::embedding",
"aten::empty",
"aten::eq",
"aten::gt",
"aten::ge",
@@ -107,6 +109,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::select",
"aten::sigmoid",
"aten::silu",
"aten::sin",
"aten::size",
"aten::slice",
"aten::softmax",
@@ -120,6 +123,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::tensor",
"aten::to",
"aten::to.dtype",
"aten::to.device",
"aten::transpose",
"aten::type_as",
"aten::unsqueeze",
@@ -128,6 +132,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"prim::Constant",
"prim::ListConstruct",
"prim::ListUnpack",
"prim::NumToTensor",
// Torch Blade custom ops follows:
"aten::add_inplace", // use aten namespace to work with PyTorch mutation pass
"aten::sub_inplace", // use aten namespace to work with PyTorch mutation pass
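
The white list is keyed by the qualified op name, occasionally including the overload (aten::to.dtype, aten::to.device). Below is a hedged sketch of how such a filter is typically consulted; isWhitelisted and its overload handling are illustrative, not the actual Blade helper:

#include <string>
#include <unordered_set>

#include <torch/csrc/jit/ir/ir.h>

// Illustrative filter check: accept a node if its qualified kind
// (e.g. "aten::to") or its overload-qualified form (e.g. "aten::to.device")
// appears in the white list.
bool isWhitelisted(const torch::jit::Node& node,
                   const std::unordered_set<std::string>& allow) {
  std::string kind = node.kind().toQualString();
  if (allow.count(kind) > 0) {
    return true;
  }
  if (const auto* schema = node.maybeSchema()) {
    if (!schema->overload_name().empty()) {
      return allow.count(kind + "." + schema->overload_name()) > 0;
    }
  }
  return false;
}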

pytorch_blade/tests/torchscript/basics.graph (27 additions, 0 deletions)
@@ -61,3 +61,30 @@ graph(%p1 : Float(*, *, *, device=cuda:0),
// CHECK: Float(*, *, *, device=cuda:0) = aten::tanh_backward(%p1, %p2)
%1 : Tensor = aten::tanh_backward(%p1, %p2)
return (%1)


// CHECK-LABEL: graph
graph(%p1 : Float(*, *, *, *, device=cuda:0),
      %p2 : Float(*, *, *, *, device=cuda:0),
      %p3 : Float(*, device=cuda:0)
     ):
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %4 : bool = prim::Constant[value=0]()
  %5 : int[] = prim::Constant[value=[0, 0]]()
  %6 : int = prim::Constant[value=1]()
  // CHECK: Float(*, *, *, *, device=cuda:0) = aten::_convolution(%p1, %p2, %p3, %3, %3, %3, %4, %5, %6, %4, %4, %4, %4)
  %7 : Tensor = aten::_convolution(%p1, %p2, %p3, %3, %3, %3, %4, %5, %6, %4, %4, %4, %4)
  return (%7)

// CHECK-LABEL: graph
graph(%p1 : Float(*, *, *, *, device=cuda:0),
      %p2 : Float(*, *, *, *, device=cuda:0),
      %p3 : Float(*, device=cuda:0)
     ):
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %4 : bool = prim::Constant[value=0]()
  %5 : int[] = prim::Constant[value=[0, 0]]()
  %6 : int = prim::Constant[value=1]()
  // CHECK: Float(*, *, *, *, device=cuda:0) = aten::convolution(%p1, %p2, %p3, %3, %3, %3, %4, %5, %6)
  %7 : Tensor = aten::convolution(%p1, %p2, %p3, %3, %3, %3, %4, %5, %6)
  return (%7)
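
The .graph cases above follow the usual FileCheck pattern: parse the IR, run shape propagation, then match the // CHECK: annotations against the re-printed graph. The driver below is a sketch built on stock PyTorch entry points; the Blade harness invokes its own ShapePropagator fork, so PropagateInputShapes here is an assumed stand-in:

#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/testing/file_check.h>

// Sketch of a .graph test driver: the source string doubles as the IR to
// parse and as the list of FileCheck directives to verify.
void runGraphTest(const std::string& src) {
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, graph.get());
  torch::jit::PropagateInputShapes(graph);            // assumed stand-in pass
  torch::jit::testing::FileCheck().run(src, *graph);  // matches // CHECK: lines
}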

pytorch_blade/tests/torchscript/since_1_10.graph (19 additions, 0 deletions)
@@ -23,3 +23,22 @@ graph(%p1 : Half(*, *)):
%1 : Tensor = aten::as_strided(%p1, %size, %strides, %none)
return (%1)

// CHECK-LABEL: graph
graph(%p1 : Float(*, *, *, device=cpu)):
  %1 : Device = prim::Constant[value="cuda:0"]()
  %2 : int = prim::Constant[value=5]()
  %3 : bool = prim::Constant[value=0]()
  %4 : NoneType = prim::Constant()
  // CHECK: Float(*, *, *, device=cuda:0) = aten::to(%p1, %1, %2, %3, %3, %4)
  %5 : Tensor = aten::to(%p1, %1, %2, %3, %3, %4)
  return (%5)

// CHECK-LABEL: graph
graph(%p1 : Float(20, 30, 40, device=cpu)):
  %1 : Device = prim::Constant[value="cuda:0"]()
  %2 : int = prim::Constant[value=5]()
  %3 : bool = prim::Constant[value=0]()
  %4 : NoneType = prim::Constant()
  // CHECK: Float(20, 30, 40, device=cuda:0) = aten::to(%p1, %1, %2, %3, %3, %4)
  %5 : Tensor = aten::to(%p1, %1, %2, %3, %3, %4)
  return (%5)

pytorch_blade/tests/torchscript/slice_like.graph (8 additions, 0 deletions)
@@ -19,3 +19,11 @@ graph(%p1 : Float(8, 768, 512)):
%slice : Tensor = aten::slice(%p1, %cst_1, %cst0, %int_max, %cst1)
return (%slice)

// CHECK-LABEL: graph
graph(%p1 : Float(1, 512, requires_grad=0, device=cpu),
      %p2 : int):
  %2 : int = prim::Constant[value=0]()
  %3 : int = prim::Constant[value=1]()
  // CHECK: Float(1, *, requires_grad=0, device=cpu) = aten::slice(%p1, %3, %2, %p2, %3)
  %4 : Tensor = aten::slice(%p1, %3, %2, %p2, %3)
  return (%4)
