Skip to content

Commit

Permalink
Update Rust bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
AD1024 authored and gussmith23 committed Dec 30, 2021
1 parent ce108c1 commit 3657700
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 5 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,7 @@ struct UpSampling3DAttrs : public tvm::AttrsNode<UpSampling3DAttrs> {
/*! \brief Attributes used for the padding operator */
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
Array<Array<Integer>> pad_width;
std::string pad_mode;
tvm::String pad_mode;

TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
TVM_ATTR_FIELD(pad_width).describe(
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer batch_dims;
Integer axis;
std::string mode;
tvm::String mode;

TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(batch_dims)
Expand Down Expand Up @@ -321,7 +321,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Optional<Array<Integer>> begin;
Optional<Array<Integer>> end;
Optional<Array<Integer>> strides;
std::string slice_mode;
tvm::String slice_mode;
Optional<Array<Integer>> axes;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
Expand Down
1 change: 1 addition & 0 deletions rust/tvm/src/ir/relay/attrs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
*/

pub mod nn;
pub mod reduce;
pub mod transform;
62 changes: 62 additions & 0 deletions rust/tvm/src/ir/relay/attrs/nn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,35 @@ use tvm_macros::Object;

type IndexExpr = PrimExpr;

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "PadAttrs"]
#[type_key = "relay.attrs.PadAttrs"]
pub struct PadAttrsNode {
pub base: BaseAttrsNode,
pub pad_width: Array<Array<IndexExpr>>,
pub pad_mode: TString,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "Conv1DAttrs"]
#[type_key = "relay.attrs.Conv1DAttrs"]
pub struct Conv1DAttrsNode {
pub base: BaseAttrsNode,
pub strides: Array<IndexExpr>,
pub padding: Array<IndexExpr>,
pub dilation: Array<IndexExpr>,
// TODO(@gussmith23) groups is "int", what should it be here?
pub groups: i32,
pub channels: IndexExpr,
pub kernel_size: Array<IndexExpr>,
pub data_layout: TString,
pub kernel_layout: TString,
pub out_layout: TString,
pub out_dtype: DataType,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "Conv2DAttrs"]
Expand All @@ -42,6 +71,7 @@ pub struct Conv2DAttrsNode {
pub data_layout: TString,
pub kernel_layout: TString,
pub out_layout: TString,
pub auto_scheduler_rewritten_layout: TString,
pub out_dtype: DataType,
}

Expand Down Expand Up @@ -138,6 +168,7 @@ pub struct AvgPool2DAttrsNode {
pub pool_size: Array<IndexExpr>,
pub strides: Array<IndexExpr>,
pub padding: Array<IndexExpr>,
pub dilation: Array<IndexExpr>,
pub layout: TString,
pub ceil_mode: bool,
pub count_include_pad: bool,
Expand All @@ -155,3 +186,34 @@ pub struct UpSamplingAttrsNode {
pub method: TString,
pub align_corners: bool,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "DropoutAttrs"]
#[type_key = "relay.attrs.DropoutAttrs"]
pub struct DropoutAttrsNode {
pub base: BaseAttrsNode,
pub rate: f64,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "BatchMatmulAttrs"]
#[type_key = "relay.attrs.BatchMatmulAttrs"]
pub struct BatchMatmulAttrsNode {
pub base: BaseAttrsNode,
pub auto_scheduler_rewritten_layout: TString,
pub out_dtype: DataType,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "LayerNormAttrs"]
#[type_key = "relay.attrs.LayerNormAttrs"]
pub struct LayerNormAttrsNode {
pub base: BaseAttrsNode,
pub axis: i32,
pub epsilon: f64,
pub center: bool,
pub scale: bool,
}
48 changes: 48 additions & 0 deletions rust/tvm/src/ir/relay/attrs/reduce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use crate::ir::attrs::BaseAttrsNode;
use crate::ir::PrimExpr;
use crate::runtime::array::Array;
use tvm_macros::Object;

type IndexExpr = PrimExpr;

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ReduceAttrs"]
#[type_key = "relay.attrs.ReduceAttrs"]
pub struct ReduceAttrsNode {
pub base: BaseAttrsNode,
pub axis: Array<IndexExpr>,
pub keepdims: bool,
pub exclude: bool,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "VarianceAttrs"]
#[type_key = "relay.attrs.ReduceAttrs"]
pub struct VarianceAttrsNode {
pub base: BaseAttrsNode,
pub axis: Array<IndexExpr>,
pub keepdims: bool,
pub exclude: bool,
pub unbiased: bool,
}
59 changes: 58 additions & 1 deletion rust/tvm/src/ir/relay/attrs/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,35 @@
*/

use crate::ir::attrs::BaseAttrsNode;
use crate::ir::relay::TString;
use crate::ir::tir::IntImm;
use crate::ir::PrimExpr;
use crate::runtime::array::Array;
use crate::runtime::ObjectRef;
use tvm_macros::Object;
use tvm_rt::DataType;

type IndexExpr = PrimExpr;

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ClipAttrs"]
#[type_key = "relay.attrs.ClipAttrs"]
pub struct ClipAttrsNode {
pub base: BaseAttrsNode,
pub a_min: f64,
pub a_max: f64,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "CastAttrs"]
#[type_key = "relay.attrs.CastAttrs"]
pub struct CastAttrsNode {
pub base: BaseAttrsNode,
pub dtype: DataType,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ExpandDimsAttrs"]
Expand Down Expand Up @@ -79,5 +101,40 @@ pub struct TransposeAttrsNode {
#[type_key = "relay.attrs.SqueezeAttrs"]
pub struct SqueezeAttrsNode {
pub base: BaseAttrsNode,
pub axis: Array<IndexExpr>,
pub axis: Array<IntImm>,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TakeAttrs"]
#[type_key = "relay.attrs.TakeAttrs"]
pub struct TakeAttrsNode {
pub base: BaseAttrsNode,
pub batch_dims: IntImm,
pub axis: IntImm,
pub mode: TString,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "StackAttrs"]
#[type_key = "relay.attrs.StackAttrs"]
pub struct StackAttrsNode {
pub base: BaseAttrsNode,
pub axis: IntImm,
}

// TODO(@gussmith23) How to support Optional type? This "just works" when values
// are provided for begin/end/strides, but I'm not sure what happens if None is
// passed from the C++ side.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "StridedSliceAttrs"]
#[type_key = "relay.attrs.StridedSliceAttrs"]
pub struct StridedSliceAttrsNode {
pub base: BaseAttrsNode,
pub begin: Array<IntImm>,
pub end: Array<IntImm>,
pub strides: Array<IntImm>,
pub slice_mode: TString,
}
4 changes: 3 additions & 1 deletion rust/tvm/src/ir/relay/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/
use crate::runtime::array::Array;
use crate::runtime::{object::*, IsObjectRef, String as TString};
use crate::runtime::{object::*, IsObjectRef, String as TString, self};

use super::attrs::Attrs;
use super::expr::BaseExprNode;
Expand Down Expand Up @@ -150,6 +150,7 @@ impl Var {
#[type_key = "relay.Call"]
pub struct CallNode {
pub base: ExprNode,
deleter: ObjectRef,
pub op: Expr,
pub args: Array<Expr>,
pub attrs: Attrs,
Expand All @@ -166,6 +167,7 @@ impl Call {
) -> Call {
let node = CallNode {
base: ExprNode::base::<CallNode>(span),
deleter: todo!("Don't know how to construct this"),
op: op,
args: args,
attrs: attrs,
Expand Down

0 comments on commit 3657700

Please sign in to comment.