Skip to content

Commit

Permalink
[Rust] More Rust bindings for Attrs (apache#7082)
Browse files Browse the repository at this point in the history
  • Loading branch information
hypercubestart authored and tkonolige committed Jan 11, 2021
1 parent 44b018b commit 124ebc6
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 3 deletions.
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
tvm::String layout;
bool ceil_mode;
bool count_include_pad;

Expand Down Expand Up @@ -977,8 +977,8 @@ struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
double scale_h;
double scale_w;
std::string layout;
std::string method;
tvm::String layout;
tvm::String method;
bool align_corners;

TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
Expand Down
36 changes: 36 additions & 0 deletions rust/tvm/src/ir/relay/attrs/nn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,39 @@ pub struct BatchNormAttrsNode {
pub center: bool,
pub scale: bool,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "LeakyReluAttrs"]
#[type_key = "relay.attrs.LeakyReluAttrs"]
pub struct LeakyReluAttrsNode {
pub base: BaseAttrsNode,
pub alpha: f64,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "AvgPool2DAttrs"]
#[type_key = "relay.attrs.AvgPool2DAttrs"]
pub struct AvgPool2DAttrsNode {
pub base: BaseAttrsNode,
pub pool_size: Array<IndexExpr>,
pub strides: Array<IndexExpr>,
pub padding: Array<IndexExpr>,
pub layout: TString,
pub ceil_mode: bool,
pub count_include_pad: bool,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "UpSamplingAttrs"]
#[type_key = "relay.attrs.UpSamplingAttrs"]
pub struct UpSamplingAttrsNode {
pub base: BaseAttrsNode,
pub scale_h: f64,
pub scale_w: f64,
pub layout: TString,
pub method: TString,
pub align_corners: bool,
}
52 changes: 52 additions & 0 deletions rust/tvm/src/ir/relay/attrs/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
*/

use crate::ir::attrs::BaseAttrsNode;
use crate::ir::PrimExpr;
use crate::runtime::array::Array;
use crate::runtime::ObjectRef;
use tvm_macros::Object;

type IndexExpr = PrimExpr;

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ExpandDimsAttrs"]
Expand All @@ -29,3 +34,50 @@ pub struct ExpandDimsAttrsNode {
pub axis: i32,
pub num_newaxis: i32,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ConcatenateAttrs"]
#[type_key = "relay.attrs.ConcatenateAttrs"]
pub struct ConcatenateAttrsNode {
pub base: BaseAttrsNode,
pub axis: i32,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ReshapeAttrs"]
#[type_key = "relay.attrs.ReshapeAttrs"]
pub struct ReshapeAttrsNode {
pub base: BaseAttrsNode,
pub newshape: Array<IndexExpr>,
pub reverse: bool,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "SplitAttrs"]
#[type_key = "relay.attrs.SplitAttrs"]
pub struct SplitAttrsNode {
pub base: BaseAttrsNode,
pub indices_or_sections: ObjectRef,
pub axis: i32,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TransposeAttrs"]
#[type_key = "relay.attrs.TransposeAttrs"]
pub struct TransposeAttrsNode {
pub base: BaseAttrsNode,
pub axes: Array<IndexExpr>,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "SqueezeAttrs"]
#[type_key = "relay.attrs.SqueezeAttrs"]
pub struct SqueezeAttrsNode {
pub base: BaseAttrsNode,
pub axis: Array<IndexExpr>,
}

0 comments on commit 124ebc6

Please sign in to comment.