Feat/candle/module ops (tracel-ai#725)
louisfd authored Aug 30, 2023
1 parent aafceef commit 760c9e1
Showing 11 changed files with 226 additions and 74 deletions.
77 changes: 39 additions & 38 deletions burn-candle/src/lib.rs
@@ -28,15 +28,15 @@ mod tests {
type TestADTensor<const D: usize, K> = burn_tensor::Tensor<TestADBackend, D, K>;

// test activation
// burn_tensor::testgen_gelu!();
// burn_tensor::testgen_relu!();
// burn_tensor::testgen_softmax!();
// burn_tensor::testgen_sigmoid!();
// burn_tensor::testgen_silu!();
burn_tensor::testgen_gelu!();
burn_tensor::testgen_relu!();
burn_tensor::testgen_softmax!();
burn_tensor::testgen_sigmoid!();
burn_tensor::testgen_silu!();

// test module
// burn_tensor::testgen_module_forward!();
// burn_tensor::testgen_module_conv1d!();
burn_tensor::testgen_module_forward!();
burn_tensor::testgen_module_conv1d!();
// burn_tensor::testgen_module_conv2d!();
// burn_tensor::testgen_module_conv_transpose1d!();
// burn_tensor::testgen_module_conv_transpose2d!();
@@ -72,7 +72,7 @@ mod tests {
burn_tensor::testgen_maxmin!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_neg!();
// burn_tensor::testgen_powf!();
burn_tensor::testgen_powf!();
burn_tensor::testgen_random!();
// burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!();
@@ -87,14 +87,15 @@ mod tests {
burn_tensor::testgen_transpose!();

// test stats
// burn_tensor::testgen_stats!();
burn_tensor::testgen_var!();
burn_tensor::testgen_display!();

// Behavior
// burn_autodiff::testgen_ad_broadcast!();

// Activation
// burn_autodiff::testgen_ad_relu!();
// burn_autodiff::testgen_ad_gelu!();
burn_autodiff::testgen_ad_relu!();
burn_autodiff::testgen_ad_gelu!();

// Modules
// burn_autodiff::testgen_ad_conv1d!();
@@ -107,36 +108,36 @@ mod tests {
// burn_autodiff::testgen_ad_avg_pool2d!();
// burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
// burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
// burn_autodiff::testgen_module_backward!();
burn_autodiff::testgen_module_backward!();

// Tensor
// burn_autodiff::testgen_ad_complex!();
// burn_autodiff::testgen_ad_multithread!();
// burn_autodiff::testgen_ad_add!();
// burn_autodiff::testgen_ad_aggregation!();
// burn_autodiff::testgen_ad_maxmin!();
burn_autodiff::testgen_ad_complex!();
burn_autodiff::testgen_ad_multithread!();
burn_autodiff::testgen_ad_add!();
burn_autodiff::testgen_ad_aggregation!();
burn_autodiff::testgen_ad_maxmin!();
// burn_autodiff::testgen_ad_cat!();
// burn_autodiff::testgen_ad_cos!();
// burn_autodiff::testgen_ad_cross_entropy_loss!();
// burn_autodiff::testgen_ad_div!();
burn_autodiff::testgen_ad_cos!();
burn_autodiff::testgen_ad_cross_entropy_loss!();
burn_autodiff::testgen_ad_div!();
// burn_autodiff::testgen_ad_erf!();
// burn_autodiff::testgen_ad_exp!();
burn_autodiff::testgen_ad_exp!();
// burn_autodiff::testgen_ad_slice!();
// burn_autodiff::testgen_ad_gather_scatter!();
// burn_autodiff::testgen_ad_select!();
// burn_autodiff::testgen_ad_log!();
// burn_autodiff::testgen_ad_log1p!();
// burn_autodiff::testgen_ad_mask!();
// burn_autodiff::testgen_ad_matmul!();
// burn_autodiff::testgen_ad_mul!();
// burn_autodiff::testgen_ad_neg!();
// burn_autodiff::testgen_ad_powf!();
// burn_autodiff::testgen_ad_reshape!();
// burn_autodiff::testgen_ad_sin!();
// burn_autodiff::testgen_ad_softmax!();
// burn_autodiff::testgen_ad_sqrt!();
// burn_autodiff::testgen_ad_abs!();
// burn_autodiff::testgen_ad_sub!();
// burn_autodiff::testgen_ad_tanh!();
// burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_gather_scatter!();
burn_autodiff::testgen_ad_select!();
burn_autodiff::testgen_ad_log!();
burn_autodiff::testgen_ad_log1p!();
burn_autodiff::testgen_ad_mask!();
burn_autodiff::testgen_ad_matmul!();
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();
burn_autodiff::testgen_ad_sqrt!();
burn_autodiff::testgen_ad_abs!();
burn_autodiff::testgen_ad_sub!();
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_transpose!();
}
11 changes: 10 additions & 1 deletion burn-candle/src/ops/activation.rs
@@ -2,10 +2,19 @@ use burn_tensor::ops::ActivationOps;

use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
CandleBackend,
tensor, CandleBackend, CandleTensor,
};

use super::base::FloatTensor;

impl<F: FloatCandleElement, I: IntCandleElement> ActivationOps<CandleBackend<F, I>>
for CandleBackend<F, I>
{
fn gelu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.gelu().unwrap())
}

fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.relu().unwrap())
}
}
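
For reference, a minimal stand-alone sketch (not part of the commit) of how the two wrapped activations behave when called directly on candle_core tensors. The relu/gelu methods are the ones used in the impl above; the sample values, expected outputs, and the main function are illustrative assumptions only.

```rust
// Sketch: sanity-checking the Candle activations wrapped above,
// assuming candle_core's Tensor API as it is used in this diff.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let x = Tensor::new(&[-1.0f32, 0.0, 1.0], &device)?;

    // relu clamps negatives to zero: [0.0, 0.0, 1.0]
    println!("{:?}", x.relu()?.to_vec1::<f32>()?);
    // gelu is smooth around zero: roughly [-0.16, 0.0, 0.84]
    println!("{:?}", x.gelu()?.to_vec1::<f32>()?);
    Ok(())
}
```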
3 changes: 1 addition & 2 deletions burn-candle/src/ops/base.rs
@@ -98,6 +98,5 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
ranges: [std::ops::Range<usize>; D2],
value: CandleTensor<E, D1>,
) -> CandleTensor<E, D1> {
// TODO: not trivial, because no view_ like in torch
todo!()
panic!("slice_assign not supported by Candle")
}
149 changes: 137 additions & 12 deletions burn-candle/src/ops/module.rs
@@ -1,24 +1,90 @@
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
use burn_tensor::{
ops::{ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps},
Shape,
};
use candle_core::ToUsize2;

use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
CandleBackend,
ops::base::reshape,
CandleBackend, CandleTensor,
};

use super::base::{FloatTensor, IntTensor};

impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
for CandleBackend<F, I>
{
fn conv1d(
x: FloatTensor<Self, 3>,
weight: FloatTensor<Self, 3>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<1>,
) -> FloatTensor<Self, 3> {
let conv = x
.tensor
.conv1d(
&weight.tensor,
options.padding[0],
options.stride[0],
options.dilation[0],
options.groups,
)
.unwrap();
CandleTensor::new(match bias {
Some(bias) => conv
.broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
.unwrap(),
None => conv,
})
}

fn conv2d(
x: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<2>,
) -> FloatTensor<Self, 4> {
todo!()
assert!(
options.dilation[0] == options.dilation[1]
&& options.padding[0] == options.padding[1]
&& options.stride[0] == options.stride[1],
"Candle does not support per dimension options in convolutions"
);
let conv = x
.tensor
.conv2d(
&weight.tensor,
options.padding[0],
options.stride[0],
options.dilation[0],
options.groups,
)
.unwrap();
CandleTensor::new(match bias {
Some(bias) => conv
.broadcast_add(
&bias
.tensor
.unsqueeze(0)
.unwrap()
.unsqueeze(2)
.unwrap()
.unsqueeze(3)
.unwrap(),
)
.unwrap(),
None => conv,
})
}

fn conv_transpose1d(
x: FloatTensor<Self, 3>,
weight: FloatTensor<Self, 3>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<Self, 3> {
panic!("Candle does not support conv_transpose1d")
}

fn conv_transpose2d(
@@ -27,7 +93,42 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<Self, 4> {
todo!()
assert!(
options.dilation[0] == options.dilation[1]
&& options.padding[0] == options.padding[1]
&& options.padding_out[0] == options.padding_out[1]
&& options.stride[0] == options.stride[1],
"Candle does not support per dimension options in transposed convolutions"
);
assert!(
options.groups == 1,
"Candle does not support groups in transposed convolutions"
);
let conv_transpose = x
.tensor
.conv_transpose2d(
&weight.tensor,
options.padding[0],
options.padding_out[0],
options.stride[0],
options.dilation[0],
)
.unwrap();
CandleTensor::new(match bias {
Some(bias) => conv_transpose
.broadcast_add(
&bias
.tensor
.unsqueeze(0)
.unwrap()
.unsqueeze(2)
.unwrap()
.unsqueeze(3)
.unwrap(),
)
.unwrap(),
None => conv_transpose,
})
}

fn avg_pool2d(
@@ -37,7 +138,19 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
padding: [usize; 2],
count_include_pad: bool,
) -> FloatTensor<Self, 4> {
todo!()
assert!(
padding[0] == 0 && padding[1] == 0,
"Candle does not support padding in pooling"
);
assert!(
count_include_pad,
"Candle does not support excluding pad count in pooling"
);
CandleTensor::new(
x.tensor
.avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
.unwrap(),
)
}

fn avg_pool2d_backward(
@@ -48,7 +161,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
padding: [usize; 2],
count_include_pad: bool,
) -> FloatTensor<Self, 4> {
todo!()
panic!("avg_pool2d_backward is not supported by Candle")
}

fn max_pool2d(
@@ -58,7 +171,19 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
padding: [usize; 2],
dilation: [usize; 2],
) -> FloatTensor<Self, 4> {
todo!()
assert!(
padding[0] == 0 && padding[1] == 0,
"Candle does not support padding in pooling"
);
assert!(
dilation[0] == 1 && dilation[1] == 1,
"Candle does not support dilation in pooling"
);
CandleTensor::new(
x.tensor
.max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
.unwrap(),
)
}

fn max_pool2d_with_indices(
@@ -68,7 +193,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<CandleBackend<F, I>> {
todo!()
panic!("max_pool2d_with_indices is not supported by Candle")
}

fn max_pool2d_with_indices_backward(
@@ -80,20 +205,20 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
output_grad: FloatTensor<Self, 4>,
indices: IntTensor<Self, 4>,
) -> MaxPool2dBackward<CandleBackend<F, I>> {
todo!()
panic!("max_pool2d_with_indices_backward is not supported by Candle")
}

fn adaptive_avg_pool2d(
x: FloatTensor<Self, 4>,
output_size: [usize; 2],
) -> FloatTensor<Self, 4> {
todo!()
panic!("adaptive_avg_pool2 is not supported by Candle")
}

fn adaptive_avg_pool2d_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
) -> FloatTensor<Self, 4> {
todo!()
panic!("adaptive_avg_pool2d_backward is not supported by Candle")
}
}
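
The triple unsqueeze in conv1d/conv2d/conv_transpose2d reshapes the bias from [C_out] to a rank matching the convolution output (e.g. [1, C_out, 1, 1] for conv2d) so that broadcast_add can spread it over the batch and spatial dimensions. Below is a stand-alone sketch (not from the commit) of that shape arithmetic, using the same candle_core calls that appear in this diff; the concrete tensor sizes are made up for illustration.

```rust
// Sketch of the conv2d bias broadcast used above; shapes are illustrative.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (1, 3, 8, 8), &device)?; // [N, C_in, H, W]
    let w = Tensor::randn(0f32, 1f32, (4, 3, 3, 3), &device)?; // [C_out, C_in, kH, kW]
    let b = Tensor::randn(0f32, 1f32, (4,), &device)?;         // [C_out]

    // padding=1, stride=1, dilation=1, groups=1 (same argument order as the impl above)
    let conv = x.conv2d(&w, 1, 1, 1, 1)?;

    // [C_out] -> [1, C_out, 1, 1], then broadcast over N, H, W.
    let bias = b.unsqueeze(0)?.unsqueeze(2)?.unsqueeze(3)?;
    let out = conv.broadcast_add(&bias)?;

    println!("{:?}", out.shape()); // [1, 4, 8, 8]
    Ok(())
}
```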
7 changes: 6 additions & 1 deletion burn-candle/src/ops/tensor.rs
@@ -369,7 +369,12 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<CandleBackend<F, I>>
}

fn powf<const D: usize>(tensor: FloatTensor<Self, D>, value: f32) -> FloatTensor<Self, D> {
panic!("powf not supported by Candle")
CandleTensor::new(
(tensor.tensor.log().unwrap() * value.elem::<f64>())
.unwrap()
.exp()
.unwrap(),
)
}

fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
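The new powf falls back to the identity x^v = exp(v · ln x), which is only well defined for strictly positive inputs (ln of zero or negative values yields -inf/NaN). A minimal sketch of the same computation done directly on a candle_core tensor, assuming only the log/exp/scalar-multiply calls shown in the diff; the sample values are illustrative.

```rust
// Sketch (not from the commit) of the powf fallback above: x^v = exp(v * ln(x)),
// valid only for strictly positive x.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let x = Tensor::new(&[1.0f32, 2.0, 4.0], &device)?;
    let v = 0.5f64;

    let y = (x.log()? * v)?.exp()?; // elementwise x^0.5
    println!("{:?}", y.to_vec1::<f32>()?); // roughly [1.0, 1.414, 2.0]
    Ok(())
}
```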
2 changes: 2 additions & 0 deletions burn-tch/src/ops/activation.rs
@@ -5,12 +5,14 @@ impl<E: TchElement> ActivationOps<TchBackend<E>> for TchBackend<E> {
fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
}

fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(
|mut tensor| tensor.gelu_("none"),
|tensor| tensor.gelu("none"),
)
}

fn gelu_backward<const D: usize>(
tensor: TchTensor<E, D>,
grad: TchTensor<E, D>,