diff --git a/burn-autodiff/src/tests/softmax.rs b/burn-autodiff/src/tests/softmax.rs index e19651384a..d282a4be45 100644 --- a/burn-autodiff/src/tests/softmax.rs +++ b/burn-autodiff/src/tests/softmax.rs @@ -46,4 +46,27 @@ mod tests { .to_data() .assert_approx_eq(&Data::from([[30.5984, -47.2267], [55.9631, -56.5914]]), 3); } + + #[test] + fn test_quiet_softmax_grad() { + let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]); + + let tensor_1 = Tensor::::from_data(data_1).require_grad(); + let tensor_2 = Tensor::::from_data(data_2).require_grad(); + + let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); + let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); + + let grads = tensor_4.backward(); + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3); + } } diff --git a/burn-core/src/nn/attention/mha.rs b/burn-core/src/nn/attention/mha.rs index 516166d656..16ed02508e 100644 --- a/burn-core/src/nn/attention/mha.rs +++ b/burn-core/src/nn/attention/mha.rs @@ -25,6 +25,14 @@ pub struct MultiHeadAttentionConfig { /// A value too low might result in NaN. #[config(default = -1.0e4)] min_float: f64, + /// Use "quiet softmax" instead of regular softmax. + /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. + /// + /// Reference: + #[config(default = false)] + quiet_softmax: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" @@ -51,6 +59,7 @@ pub struct MultiHeadAttention { n_heads: usize, d_k: usize, min_float: f64, + quiet_softmax: bool, } /// [Multihead attention](MultiHeadAttention) forward pass input argument. @@ -82,6 +91,7 @@ impl MultiHeadAttentionConfig { n_heads: self.n_heads, d_k: self.d_model / self.n_heads, min_float: self.min_float, + quiet_softmax: self.quiet_softmax, } } @@ -105,6 +115,7 @@ impl MultiHeadAttentionConfig { n_heads: self.n_heads, d_k: self.d_model / self.n_heads, min_float: self.min_float, + quiet_softmax: self.quiet_softmax, } } } @@ -249,7 +260,11 @@ impl MultiHeadAttention { ); } - activation::softmax(attn_scores, 3) + if self.quiet_softmax { + activation::quiet_softmax(attn_scores, 3) + } else { + activation::softmax(attn_scores, 3) + } } fn attention_linear(&self, x: Tensor, linear: &nn::Linear) -> Tensor { diff --git a/burn-core/src/nn/loss/binary_cross_entropy.rs b/burn-core/src/nn/loss/binary_cross_entropy.rs index 506e34ab71..7cb85cec66 100644 --- a/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -45,10 +45,10 @@ impl BinaryCrossEntropyLossConfig { fn assertions(&self) { if let Some(alpha) = self.smoothing { assert!( - (0.0..=1.).contains(&alpha), - "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", - alpha - ); + (0.0..=1.).contains(&alpha), + "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. 
Got {}", + alpha + ); }; if let Some(weights) = self.weights.as_ref() { assert!( diff --git a/burn-core/src/nn/loss/cross_entropy.rs b/burn-core/src/nn/loss/cross_entropy.rs index 99e6fd5cc6..41c5a0b338 100644 --- a/burn-core/src/nn/loss/cross_entropy.rs +++ b/burn-core/src/nn/loss/cross_entropy.rs @@ -53,10 +53,10 @@ impl CrossEntropyLossConfig { fn assertions(&self) { if let Some(alpha) = self.smoothing { assert!( - (0.0..=1.).contains(&alpha), - "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", - alpha - ); + (0.0..=1.).contains(&alpha), + "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}", + alpha + ); }; if let Some(weights) = self.weights.as_ref() { assert!( diff --git a/burn-core/src/nn/transformer/decoder.rs b/burn-core/src/nn/transformer/decoder.rs index db5afd5b74..5033a86692 100644 --- a/burn-core/src/nn/transformer/decoder.rs +++ b/burn-core/src/nn/transformer/decoder.rs @@ -34,6 +34,14 @@ pub struct TransformerDecoderConfig { /// Layer norm will be applied first instead of after the other modules. #[config(default = false)] pub norm_first: bool, + /// Use "quiet softmax" instead of regular softmax. + /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. + /// + /// Reference: + #[config(default = false)] + pub quiet_softmax: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" @@ -186,11 +194,13 @@ impl TransformerDecoderLayer { let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) .init(); let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) .init(); let norm_1 = LayerNormConfig::new(config.d_model).init(); let norm_2 = LayerNormConfig::new(config.d_model).init(); @@ -219,10 +229,12 @@ impl TransformerDecoderLayer { let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) .init_with(record.self_attn); let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) .init_with(record.cross_attn); let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); diff --git a/burn-core/src/nn/transformer/encoder.rs b/burn-core/src/nn/transformer/encoder.rs index 3d5c17d601..7e7b4fa7e3 100644 --- a/burn-core/src/nn/transformer/encoder.rs +++ b/burn-core/src/nn/transformer/encoder.rs @@ -34,6 +34,14 @@ pub struct TransformerEncoderConfig { /// Layer norm will be applied first instead of after the other modules. #[config(default = false)] pub norm_first: bool, + /// Use "quiet softmax" instead of regular softmax. 
+ /// + /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). + /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. + /// + /// Reference: + #[config(default = false)] + pub quiet_softmax: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}" @@ -175,6 +183,7 @@ impl TransformerEncoderLayer { let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) .init_with(record.mha); let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1); let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2); @@ -197,6 +206,7 @@ impl TransformerEncoderLayer { let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) + .with_quiet_softmax(config.quiet_softmax) .init(); let norm_1 = LayerNormConfig::new(config.d_model).init(); let norm_2 = LayerNormConfig::new(config.d_model).init(); diff --git a/burn-derive/src/config/analyzer_enum.rs b/burn-derive/src/config/analyzer_enum.rs index e926f4c502..2f7e2347b1 100644 --- a/burn-derive/src/config/analyzer_enum.rs +++ b/burn-derive/src/config/analyzer_enum.rs @@ -72,11 +72,11 @@ impl ConfigEnumAnalyzer { fn gen_serialize_fn(&self) -> TokenStream { let enum_name = self.serde_enum_ident(); let variants = self.data.variants.iter().map(|variant| { - let variant_name = &variant.ident; - let (variant_input, variant_output) = self.gen_variant_field(variant); + let variant_name = &variant.ident; + let (variant_input, variant_output) = self.gen_variant_field(variant); - quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output } - }); + quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output } + }); let name = &self.name; quote! { @@ -97,11 +97,11 @@ impl ConfigEnumAnalyzer { fn gen_deserialize_fn(&self) -> TokenStream { let enum_name = self.serde_enum_ident(); let variants = self.data.variants.iter().map(|variant| { - let variant_name = &variant.ident; - let (variant_input, variant_output) = self.gen_variant_field(variant); + let variant_name = &variant.ident; + let (variant_input, variant_output) = self.gen_variant_field(variant); - quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output } - }); + quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output } + }); let name = &self.name; quote! { diff --git a/burn-derive/src/record/codegen_struct.rs b/burn-derive/src/record/codegen_struct.rs index fcefaaad48..331b38b308 100644 --- a/burn-derive/src/record/codegen_struct.rs +++ b/burn-derive/src/record/codegen_struct.rs @@ -23,9 +23,9 @@ impl RecordItemCodegen for StructRecordItemCodegen { /// Field to be serialized. pub #name: <#ty as burn::record::Record>::Item, }); - bounds.extend(quote!{ - <#ty as burn::record::Record>::Item: serde::Serialize + serde::de::DeserializeOwned, - }); + bounds.extend(quote! 
{ + <#ty as burn::record::Record>::Item: serde::Serialize + serde::de::DeserializeOwned, + }); } let bound = bounds.to_string(); diff --git a/burn-import/src/burn/node/binary.rs b/burn-import/src/burn/node/binary.rs index 2e237339a2..f874fc8ab5 100644 --- a/burn-import/src/burn/node/binary.rs +++ b/burn-import/src/burn/node/binary.rs @@ -167,25 +167,25 @@ mod tests { use crate::burn::{ScalarKind, ScalarType, TensorType}; macro_rules! test_binary_operator_on_tensors { - ($operator:ident) => {{ - one_node_graph( - BinaryNode::$operator( - Type::Tensor(TensorType::new_float("tensor1", 4)), - Type::Tensor(TensorType::new_float("tensor2", 4)), - Type::Tensor(TensorType::new_float("tensor3", 4)), - ), - quote! { - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.$operator(tensor2); + ($operator:ident) => {{ + one_node_graph( + BinaryNode::$operator( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + Type::Tensor(TensorType::new_float("tensor3", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.$operator(tensor2); - tensor3 - } - }, - vec!["tensor1".to_string(), "tensor2".to_string()], - vec!["tensor3".to_string()], - ); - }}; - } + tensor3 + } + }, + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + }}; + } macro_rules! test_binary_operator_on_tensor_and_scalar { ($operator:ident, $burn_operator:ident) => {{ diff --git a/burn-ndarray/src/ops/base.rs b/burn-ndarray/src/ops/base.rs index bb73a02d43..7b2b9f96b9 100644 --- a/burn-ndarray/src/ops/base.rs +++ b/burn-ndarray/src/ops/base.rs @@ -354,7 +354,10 @@ where for i in 0..D - 1 { if shape_tensor.dims[i] != shape_indices.dims[i] { - panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indices.dims); + panic!( + "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", + shape_tensor.dims, shape_indices.dims + ); } batch_size *= shape_indices.dims[i]; } diff --git a/burn-tensor/src/tensor/activation/base.rs b/burn-tensor/src/tensor/activation/base.rs index 5fa502422c..944de1cab6 100644 --- a/burn-tensor/src/tensor/activation/base.rs +++ b/burn-tensor/src/tensor/activation/base.rs @@ -31,6 +31,26 @@ pub fn softmax(tensor: Tensor, dim: usize) -> tensor.div(tensor_tmp) } +/// Applies the "quiet softmax" function on the input tensor along the given dimension. +/// This function is similar to the softmax function, but it allows for "no selection", e.g., +/// all outputs can tend to zero. +/// +/// `softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]` +/// +/// # Notes +/// +/// The dimension argument `dim` specifies the dimension along which the function will be computed. +/// It must in the range of `0` and `D-1`. +pub fn quiet_softmax(tensor: Tensor, dim: usize) -> Tensor { + check!(TensorCheck::dim_ops::("softmax", dim)); + + let tensor = tensor.clone() - tensor.detach().max_dim(dim); + let tensor = tensor.exp(); + let tensor_tmp = tensor.clone().sum_dim(dim); + + tensor.div(tensor_tmp + 1) +} + /// Applies the log softmax function on the input tensor along the given dimension. 
/// /// `log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))` diff --git a/burn-tensor/src/tensor/api/check.rs b/burn-tensor/src/tensor/api/check.rs index 0cecc57f9c..713ffc484a 100644 --- a/burn-tensor/src/tensor/api/check.rs +++ b/burn-tensor/src/tensor/api/check.rs @@ -87,13 +87,16 @@ impl TensorCheck { let mut check = Self::Ok; if original.num_elements() != target.num_elements() { - check = check.register("Reshape", TensorError::new( - "The given shape doesn't have the same number of elements as the current tensor.", - ) - .details(format!( - "Current shape: {:?}, target shape: {:?}.", - original.dims, target.dims - ))); + check = check.register( + "Reshape", + TensorError::new( + "The given shape doesn't have the same number of elements as the current tensor.", + ) + .details(format!( + "Current shape: {:?}, target shape: {:?}.", + original.dims, target.dims + )), + ); } check @@ -251,8 +254,8 @@ impl TensorCheck { check = check.register( "Matmul", TensorError::new(format!( - "The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}." - )) + "The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}." + )) .details(format!( "Lhs shape {:?}, rhs shape {:?}.", shape_lhs.dims, shape_rhs.dims @@ -298,15 +301,16 @@ impl TensorCheck { if shape_reference != shape { return check.register( - "Cat", - TensorError::new("Can't concatenate tensors with different shapes, except for the provided dimension").details( - format!( - "Provided dimension ({}), tensors shapes: {:?}", - dim, - tensors.iter().map(Tensor::shape).collect::>() - ), - ), - ); + "Cat", + TensorError::new( + "Can't concatenate tensors with different shapes, except for the provided dimension", + ) + .details(format!( + "Provided dimension ({}), tensors shapes: {:?}", + dim, + tensors.iter().map(Tensor::shape).collect::>() + )), + ); } } @@ -337,18 +341,16 @@ impl TensorCheck { if range.end > d_tensor { check = check.register( - "Slice", - TensorError::new("The provided ranges array has a range that exceeds the current tensor size.") - .details(format!( - "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ + "Slice", + TensorError::new( + "The provided ranges array has a range that exceeds the current tensor size.", + ) + .details(format!( + "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ Tensor shape {:?}, provided ranges {:?}.", - range.start, - range.end, - d_tensor, - i, - shape.dims, - ranges, - ))); + range.start, range.end, d_tensor, i, shape.dims, ranges, + )), + ); } if range.start >= range.end { @@ -378,13 +380,16 @@ impl TensorCheck { let mut check = Self::Ok; if D1 < D2 { - check = check.register("Slice Assign", - TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.") - .details( - format!( - "The ranges array must be smaller or equal to the tensor number of dimensions. \ + check = check.register( + "Slice Assign", + TensorError::new( + "The provided ranges array has a higher number of dimensions than the current tensor.", + ) + .details(format!( + "The ranges array must be smaller or equal to the tensor number of dimensions. \ Tensor number of dimensions: {D1}, ranges array length {D2}." 
- ))); + )), + ); } for i in 0..usize::min(D1, D2) { @@ -394,19 +399,16 @@ impl TensorCheck { if range.end > d_tensor { check = check.register( - "Range Assign", - TensorError::new("The provided ranges array has a range that exceeds the current tensor size.") - .details(format!( - "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ + "Range Assign", + TensorError::new( + "The provided ranges array has a range that exceeds the current tensor size.", + ) + .details(format!( + "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.", - range.start, - range.end, - d_tensor, - i, - shape.dims, - shape_value.dims, - ranges, - ))); + range.start, range.end, d_tensor, i, shape.dims, shape_value.dims, ranges, + )), + ); } if range.end - range.start != d_tensor_value { @@ -596,17 +598,16 @@ impl TensorCheck { continue; } - check = check.register(ops, - TensorError::new("The provided tensors have incompatible shapes.") - .details(format!( - "Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \ + check = check.register( + ops, + TensorError::new("The provided tensors have incompatible shapes.").details( + format!( + "Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \ Lhs tensor shape {:?}, Rhs tensor shape {:?}.", - i, - d_lhs, - d_rhs, - lhs.dims, - rhs.dims, - ))); + i, d_lhs, d_rhs, lhs.dims, rhs.dims, + ), + ), + ); } } diff --git a/burn-tensor/src/tensor/data.rs b/burn-tensor/src/tensor/data.rs index 1794e03cd1..755c4af299 100644 --- a/burn-tensor/src/tensor/data.rs +++ b/burn-tensor/src/tensor/data.rs @@ -320,10 +320,9 @@ impl + Clone + core::fmt::Debug + PartialEq, const D: usize> Data tolerance { // Only print the first 5 different values. if num_diff < max_num_diff { - message += format!( - "\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}" - ) - .as_str(); + message += + format!("\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}") + .as_str(); } num_diff += 1; } diff --git a/burn-tensor/src/tests/activation/quiet_softmax.rs b/burn-tensor/src/tests/activation/quiet_softmax.rs new file mode 100644 index 0000000000..7a3733db70 --- /dev/null +++ b/burn-tensor/src/tests/activation/quiet_softmax.rs @@ -0,0 +1,16 @@ +#[burn_tensor_testgen::testgen(quiet_softmax)] +mod tests { + use super::*; + use burn_tensor::{activation, Data, Tensor}; + + #[test] + fn test_quiet_softmax_d2() { + let data = Data::from([[1.0, 7.0], [13.0, -3.0]]); + let tensor = Tensor::::from_data(data); + + let data_actual = activation::quiet_softmax(tensor, 1).to_data(); + + let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]); + data_actual.assert_approx_eq(&data_expected, 4); + } +}
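
Illustrative usage sketch of the two entry points this diff introduces, assuming the module paths implied by the touched files (`burn_core::nn::attention::MultiHeadAttentionConfig`, `burn_tensor::activation::quiet_softmax`); the sizes 64/4 and the helper names below are arbitrary placeholders, not part of the patch.

use burn_core::nn::attention::MultiHeadAttentionConfig;
use burn_tensor::{activation, backend::Backend, Tensor};

/// Quiet softmax normalizes by `1 + sum_j(exp(x_j))` instead of `sum_j(exp(x_j))`,
/// so when every score is strongly negative all outputs can tend to zero
/// ("no selection") rather than being forced to sum to one.
fn attention_weights<B: Backend>(scores: Tensor<B, 2>) -> Tensor<B, 2> {
    activation::quiet_softmax(scores, 1)
}

/// Opting a multi-head attention module into quiet softmax via the generated
/// `with_quiet_softmax` builder; every other field keeps its default, in the
/// same way as the existing `min_float` option.
fn quiet_attention_config() -> MultiHeadAttentionConfig {
    // d_model = 64, n_heads = 4 (illustrative values only).
    MultiHeadAttentionConfig::new(64, 4).with_quiet_softmax(true)
}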