From f27ee77617c10833dcbf755306e2a5580dd3d2f5 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Sun, 9 Jan 2022 15:21:35 -0800 Subject: [PATCH] [Rust] Update Rust bindings (#9808) * Update Rust bindings * fmt Co-authored-by: AD1024 --- include/tvm/relay/attrs/nn.h | 2 +- include/tvm/relay/attrs/transform.h | 4 +- rust/tvm/src/ir/relay/attrs/mod.rs | 1 + rust/tvm/src/ir/relay/attrs/nn.rs | 62 ++++++++++++++++++++++++ rust/tvm/src/ir/relay/attrs/reduce.rs | 48 ++++++++++++++++++ rust/tvm/src/ir/relay/attrs/transform.rs | 59 +++++++++++++++++++++- rust/tvm/src/ir/relay/mod.rs | 4 +- 7 files changed, 175 insertions(+), 5 deletions(-) create mode 100644 rust/tvm/src/ir/relay/attrs/reduce.rs diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 26d2c72c824d..e9f3552870a5 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1232,7 +1232,7 @@ struct UpSampling3DAttrs : public tvm::AttrsNode { /*! \brief Attributes used for the padding operator */ struct PadAttrs : public tvm::AttrsNode { Array> pad_width; - std::string pad_mode; + tvm::String pad_mode; TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { TVM_ATTR_FIELD(pad_width).describe( diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 28f91723696c..07723df56925 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -173,7 +173,7 @@ struct GatherNDAttrs : public tvm::AttrsNode { struct TakeAttrs : public tvm::AttrsNode { Integer batch_dims; Integer axis; - std::string mode; + tvm::String mode; TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { TVM_ATTR_FIELD(batch_dims) @@ -329,7 +329,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode { Optional> begin; Optional> end; Optional> strides; - std::string slice_mode; + tvm::String slice_mode; Optional> axes; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { diff --git a/rust/tvm/src/ir/relay/attrs/mod.rs b/rust/tvm/src/ir/relay/attrs/mod.rs index d1bcc0009657..333ed26752fc 100644 --- a/rust/tvm/src/ir/relay/attrs/mod.rs +++ b/rust/tvm/src/ir/relay/attrs/mod.rs @@ -18,4 +18,5 @@ */ pub mod nn; +pub mod reduce; pub mod transform; diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index f003ae627aec..e1c572ae3451 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -26,6 +26,35 @@ use tvm_macros::Object; type IndexExpr = PrimExpr; +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "PadAttrs"] +#[type_key = "relay.attrs.PadAttrs"] +pub struct PadAttrsNode { + pub base: BaseAttrsNode, + pub pad_width: Array>, + pub pad_mode: TString, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "Conv1DAttrs"] +#[type_key = "relay.attrs.Conv1DAttrs"] +pub struct Conv1DAttrsNode { + pub base: BaseAttrsNode, + pub strides: Array, + pub padding: Array, + pub dilation: Array, + // TODO(@gussmith23) groups is "int", what should it be here? + pub groups: i32, + pub channels: IndexExpr, + pub kernel_size: Array, + pub data_layout: TString, + pub kernel_layout: TString, + pub out_layout: TString, + pub out_dtype: DataType, +} + #[repr(C)] #[derive(Object, Debug)] #[ref_name = "Conv2DAttrs"] @@ -42,6 +71,7 @@ pub struct Conv2DAttrsNode { pub data_layout: TString, pub kernel_layout: TString, pub out_layout: TString, + pub auto_scheduler_rewritten_layout: TString, pub out_dtype: DataType, } @@ -138,6 +168,7 @@ pub struct AvgPool2DAttrsNode { pub pool_size: Array, pub strides: Array, pub padding: Array, + pub dilation: Array, pub layout: TString, pub ceil_mode: bool, pub count_include_pad: bool, @@ -155,3 +186,34 @@ pub struct UpSamplingAttrsNode { pub method: TString, pub align_corners: bool, } + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "DropoutAttrs"] +#[type_key = "relay.attrs.DropoutAttrs"] +pub struct DropoutAttrsNode { + pub base: BaseAttrsNode, + pub rate: f64, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "BatchMatmulAttrs"] +#[type_key = "relay.attrs.BatchMatmulAttrs"] +pub struct BatchMatmulAttrsNode { + pub base: BaseAttrsNode, + pub auto_scheduler_rewritten_layout: TString, + pub out_dtype: DataType, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "LayerNormAttrs"] +#[type_key = "relay.attrs.LayerNormAttrs"] +pub struct LayerNormAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32, + pub epsilon: f64, + pub center: bool, + pub scale: bool, +} diff --git a/rust/tvm/src/ir/relay/attrs/reduce.rs b/rust/tvm/src/ir/relay/attrs/reduce.rs new file mode 100644 index 000000000000..aed84fdf2aad --- /dev/null +++ b/rust/tvm/src/ir/relay/attrs/reduce.rs @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::ir::attrs::BaseAttrsNode; +use crate::ir::PrimExpr; +use crate::runtime::array::Array; +use tvm_macros::Object; + +type IndexExpr = PrimExpr; + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "ReduceAttrs"] +#[type_key = "relay.attrs.ReduceAttrs"] +pub struct ReduceAttrsNode { + pub base: BaseAttrsNode, + pub axis: Array, + pub keepdims: bool, + pub exclude: bool, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "VarianceAttrs"] +#[type_key = "relay.attrs.ReduceAttrs"] +pub struct VarianceAttrsNode { + pub base: BaseAttrsNode, + pub axis: Array, + pub keepdims: bool, + pub exclude: bool, + pub unbiased: bool, +} diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs index b5f7c2047d62..d86c46a6f6bb 100644 --- a/rust/tvm/src/ir/relay/attrs/transform.rs +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -18,13 +18,35 @@ */ use crate::ir::attrs::BaseAttrsNode; +use crate::ir::relay::TString; +use crate::ir::tir::IntImm; use crate::ir::PrimExpr; use crate::runtime::array::Array; use crate::runtime::ObjectRef; use tvm_macros::Object; +use tvm_rt::DataType; type IndexExpr = PrimExpr; +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "ClipAttrs"] +#[type_key = "relay.attrs.ClipAttrs"] +pub struct ClipAttrsNode { + pub base: BaseAttrsNode, + pub a_min: f64, + pub a_max: f64, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "CastAttrs"] +#[type_key = "relay.attrs.CastAttrs"] +pub struct CastAttrsNode { + pub base: BaseAttrsNode, + pub dtype: DataType, +} + #[repr(C)] #[derive(Object, Debug)] #[ref_name = "ExpandDimsAttrs"] @@ -79,5 +101,40 @@ pub struct TransposeAttrsNode { #[type_key = "relay.attrs.SqueezeAttrs"] pub struct SqueezeAttrsNode { pub base: BaseAttrsNode, - pub axis: Array, + pub axis: Array, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "TakeAttrs"] +#[type_key = "relay.attrs.TakeAttrs"] +pub struct TakeAttrsNode { + pub base: BaseAttrsNode, + pub batch_dims: IntImm, + pub axis: IntImm, + pub mode: TString, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "StackAttrs"] +#[type_key = "relay.attrs.StackAttrs"] +pub struct StackAttrsNode { + pub base: BaseAttrsNode, + pub axis: IntImm, +} + +// TODO(@gussmith23) How to support Optional type? This "just works" when values +// are provided for begin/end/strides, but I'm not sure what happens if None is +// passed from the C++ side. +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "StridedSliceAttrs"] +#[type_key = "relay.attrs.StridedSliceAttrs"] +pub struct StridedSliceAttrsNode { + pub base: BaseAttrsNode, + pub begin: Array, + pub end: Array, + pub strides: Array, + pub slice_mode: TString, } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 404cca4946fb..abc25e89c48c 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -17,7 +17,7 @@ * under the License. */ use crate::runtime::array::Array; -use crate::runtime::{object::*, IsObjectRef, String as TString}; +use crate::runtime::{self, object::*, IsObjectRef, String as TString}; use super::attrs::Attrs; use super::expr::BaseExprNode; @@ -150,6 +150,7 @@ impl Var { #[type_key = "relay.Call"] pub struct CallNode { pub base: ExprNode, + deleter: ObjectRef, pub op: Expr, pub args: Array, pub attrs: Attrs, @@ -166,6 +167,7 @@ impl Call { ) -> Call { let node = CallNode { base: ExprNode::base::(span), + deleter: todo!("Don't know how to construct this"), op: op, args: args, attrs: attrs,