Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX avgpool1d #1744

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/burn-core/src/nn/pool/avg_pool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::tensor::Tensor;
use burn_tensor::module::avg_pool1d;

/// Configuration to create a [1D avg pooling](AvgPool1d) layer.
#[derive(Config)]
#[derive(Config, Debug)]
pub struct AvgPool1dConfig {
/// The size of the kernel.
pub kernel_size: usize,
Expand All @@ -20,7 +20,7 @@ pub struct AvgPool1dConfig {
pub padding: PaddingConfig1d,
/// If the padding is counted in the denominator when computing the average.
#[config(default = "true")]
count_include_pad: bool,
pub count_include_pad: bool,
}

/// Applies a 1D avg pooling over input tensors.
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ represent the corresponding Burn Op.
| [Asinh][9] | ❌ | ❌ |
| [Atan][10] | ❌ | ❌ |
| [Atanh][11] | ❌ | ❌ |
| [AveragePool1d][12] | | ✅ |
| [AveragePool1d][12] | | ✅ |
| [AveragePool2d][12] | ✅ | ✅ |
| [BatchNormalization][14] | ✅ | ✅ |
| [Bernoulli][15] | ❌ | ❌ |
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ fn main() {
ModelGen::new()
.input("tests/add/add_int.onnx")
.input("tests/add/add.onnx")
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/cast/cast.onnx")
Expand Down
Binary file not shown.
58 changes: 58 additions & 0 deletions crates/burn-import/onnx-tests/tests/avg_pool1d/avg_pool1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

# used to generate model: avg_pool1d.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.pool1 = nn.AvgPool1d(4, stride=2)

self.pool2 = nn.AvgPool1d(4, stride=2, padding=2, count_include_pad=True)

self.pool3 = nn.AvgPool1d(4, stride=2, padding=2, count_include_pad=False)

def forward(self, x1, x2, x3):
y1 = self.pool1(x1)
y2 = self.pool2(x2)
y3 = self.pool3(x3)
return y1, y2, y3


def main():
# Set seed for reproducibility
torch.manual_seed(1)

# Print options
torch.set_printoptions(precision=3)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "avg_pool1d.onnx"
input1 = torch.randn(1, 5, 5, device=device)
torch.onnx.export(model, (input1, input1, input1), file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape: {}".format(input1.shape))
print("Test input data: {}".format(input1))
output1, output2, output3 = model.forward(input1, input1, input1)
print("Test output1 data shape: {}".format(output1.shape))
print("Test output2 data shape: {}".format(output2.shape))
print("Test output3 data shape: {}".format(output3.shape))
print("Test output1: {}".format(output1))
print("Test output2: {}".format(output2))
print("Test output3: {}".format(output3))


if __name__ == '__main__':
main()
48 changes: 48 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include_models!(
add_int,
add,
avg_pool2d,
avg_pool1d,
batch_norm,
cast,
clip_opset16,
Expand Down Expand Up @@ -498,6 +499,53 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn avg_pool1d() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: avg_pool1d::Model<Backend> = avg_pool1d::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 3>::from_floats(
[[
[-1.526, -0.750, -0.654, -1.609, -0.100],
[-0.609, -0.980, -1.609, -0.712, 1.171],
[1.767, -0.095, 0.139, -1.579, -0.321],
[-0.299, 1.879, 0.336, 0.275, 1.716],
[-0.056, 0.911, -1.392, 2.689, -0.111],
]],
&device,
);
let (output1, output2, output3) = model.forward(input.clone(), input.clone(), input);
let expected1 = Data::from([[[-1.135], [-0.978], [0.058], [0.548], [0.538]]]);
let expected2 = Data::from([[
[-0.569, -1.135, -0.591],
[-0.397, -0.978, -0.288],
[0.418, 0.058, -0.440],
[0.395, 0.548, 0.582],
[0.214, 0.538, 0.296],
]]);
let expected3 = Data::from([[
[-1.138, -1.135, -0.788],
[-0.794, -0.978, -0.383],
[0.836, 0.058, -0.587],
[0.790, 0.548, 0.776],
[0.427, 0.538, 0.395],
]]);

let expected_shape1 = Shape::from([1, 5, 1]);
let expected_shape2 = Shape::from([1, 5, 3]);
let expected_shape3 = Shape::from([1, 5, 3]);

assert_eq!(output1.shape(), expected_shape1);
assert_eq!(output2.shape(), expected_shape2);
assert_eq!(output3.shape(), expected_shape3);

output1.to_data().assert_approx_eq(&expected1, 3);
output2.to_data().assert_approx_eq(&expected2, 3);
output3.to_data().assert_approx_eq(&expected3, 3);
}

#[test]
fn avg_pool2d() {
// Initialize the model without weights (because the exported file does not contain them)
Expand Down
155 changes: 155 additions & 0 deletions crates/burn-import/src/burn/node/avg_pool1d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
use proc_macro2::TokenStream;
use quote::quote;

use burn::{nn::pool::AvgPool1dConfig, record::PrecisionSettings};

use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};

#[derive(Debug, Clone)]
pub struct AvgPool1dNode {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub config: AvgPool1dConfig,
}

impl AvgPool1dNode {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
config: AvgPool1dConfig,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
AvgPool1d
},
),
input,
output,
config,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool1dNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.stride.to_tokens();
let padding = self.config.padding.to_tokens();
let count_include_pad = self.config.count_include_pad;

let tokens = quote! {
let #name = AvgPool1dConfig::new(#kernel_size)
.with_stride(#strides)
.with_padding(#padding)
.with_count_include_pad(#count_include_pad)
.init();
};

Some(tokens)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;

quote! {
let #output = self.#field.forward(#input);
}
}

fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::PaddingConfig1d");
imports.register("burn::nn::pool::AvgPool1d");
imports.register("burn::nn::pool::AvgPool1dConfig");
}

fn into_node(self) -> Node<PS> {
Node::AvgPool1d(self)
}

fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
S::serialize_none(serializer)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
use burn::{nn::PaddingConfig1d, record::FullPrecisionSettings};

#[test]
fn test_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(AvgPool1dNode::new(
"avg_pool1d",
TensorType::new_float("input", 3),
TensorType::new_float("output", 3),
AvgPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Valid),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::nn::PaddingConfig1d;
use burn::nn::pool::AvgPool1d;
use burn::nn::pool::AvgPool1dConfig;

#[derive(Module, Debug)]
pub struct Model <B: Backend> {
avg_pool1d: AvgPool1d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let avg_pool1d = AvgPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Valid)
.with_count_include_pad(true)
.init();

Self {
avg_pool1d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output = self.avg_pool1d.forward(input);

output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
17 changes: 10 additions & 7 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use super::{
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -75,6 +75,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {

#[derive(Debug, Clone)]
pub enum Node<PS: PrecisionSettings> {
AvgPool1d(AvgPool1dNode),
AvgPool2d(AvgPool2dNode),
BatchNorm(BatchNormNode<PS>),
Binary(BinaryNode),
Expand Down Expand Up @@ -103,6 +104,7 @@ macro_rules! match_all {
($self:expr, $func:expr) => {{
#[allow(clippy::redundant_closure_call)]
match $self {
Node::AvgPool1d(node) => $func(node),
Node::AvgPool2d(node) => $func(node),
Node::BatchNorm(node) => $func(node),
Node::Binary(node) => $func(node),
Expand Down Expand Up @@ -141,6 +143,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
impl<PS: PrecisionSettings> Node<PS> {
pub fn name(&self) -> &str {
match self {
Node::AvgPool1d(_) => "avg_pool1d",
Node::AvgPool2d(_) => "avg_pool2d",
Node::BatchNorm(_) => "batch_norm",
Node::Binary(binary) => binary.binary_type.as_str(),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod base;

pub(crate) mod avg_pool1d;
pub(crate) mod avg_pool2d;
pub(crate) mod batch_norm;
pub(crate) mod binary;
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::AveragePool1d => same_as_input(node),
NodeType::AveragePool2d => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
NodeType::Cast => cast_update_outputs(node),
Expand All @@ -38,6 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Log => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
NodeType::MatMul => matmul_update_outputs(node),
NodeType::MaxPool1d => same_as_input(node),
NodeType::MaxPool2d => same_as_input(node),
NodeType::Mul => same_as_input(node),
NodeType::Neg => same_as_input(node),
Expand Down
Loading
Loading