diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 4d867be90f5a7..c3c58e54517cc 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -717,7 +717,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; Array padding; - std::string layout; + tvm::String layout; bool ceil_mode; bool count_include_pad; @@ -977,8 +977,8 @@ struct FIFOBufferAttrs : public tvm::AttrsNode { struct UpSamplingAttrs : public tvm::AttrsNode { double scale_h; double scale_w; - std::string layout; - std::string method; + tvm::String layout; + tvm::String method; bool align_corners; TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index 7ecd92febc228..f0137fa3cbccb 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -106,3 +106,39 @@ pub struct BatchNormAttrsNode { pub center: bool, pub scale: bool, } + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "LeakyReluAttrs"] +#[type_key = "relay.attrs.LeakyReluAttrs"] +pub struct LeakyReluAttrsNode { + pub base: BaseAttrsNode, + pub alpha: f64, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "AvgPool2DAttrs"] +#[type_key = "relay.attrs.AvgPool2DAttrs"] +pub struct AvgPool2DAttrsNode { + pub base: BaseAttrsNode, + pub pool_size: Array, + pub strides: Array, + pub padding: Array, + pub layout: TString, + pub ceil_mode: bool, + pub count_include_pad: bool, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "UpSamplingAttrs"] +#[type_key = "relay.attrs.UpSamplingAttrs"] +pub struct UpSamplingAttrsNode { + pub base: BaseAttrsNode, + pub scale_h: f64, + pub scale_w: f64, + pub layout: TString, + pub method: TString, + pub align_corners: bool, +} diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs index c459f96b2d2fa..b5f7c2047d621 100644 --- a/rust/tvm/src/ir/relay/attrs/transform.rs +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -18,8 +18,13 @@ */ use crate::ir::attrs::BaseAttrsNode; +use crate::ir::PrimExpr; +use crate::runtime::array::Array; +use crate::runtime::ObjectRef; use tvm_macros::Object; +type IndexExpr = PrimExpr; + #[repr(C)] #[derive(Object, Debug)] #[ref_name = "ExpandDimsAttrs"] @@ -29,3 +34,50 @@ pub struct ExpandDimsAttrsNode { pub axis: i32, pub num_newaxis: i32, } + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "ConcatenateAttrs"] +#[type_key = "relay.attrs.ConcatenateAttrs"] +pub struct ConcatenateAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "ReshapeAttrs"] +#[type_key = "relay.attrs.ReshapeAttrs"] +pub struct ReshapeAttrsNode { + pub base: BaseAttrsNode, + pub newshape: Array, + pub reverse: bool, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "SplitAttrs"] +#[type_key = "relay.attrs.SplitAttrs"] +pub struct SplitAttrsNode { + pub base: BaseAttrsNode, + pub indices_or_sections: ObjectRef, + pub axis: i32, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "TransposeAttrs"] +#[type_key = "relay.attrs.TransposeAttrs"] +pub struct TransposeAttrsNode { + pub base: BaseAttrsNode, + pub axes: Array, +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "SqueezeAttrs"] +#[type_key = "relay.attrs.SqueezeAttrs"] +pub struct SqueezeAttrsNode { + pub base: BaseAttrsNode, + pub axis: Array, +}