Implement Quiet Softmax (Attention Is Off By One) #692

Merged: 3 commits, Nov 30, 2023
23 changes: 23 additions & 0 deletions burn-autodiff/src/tests/softmax.rs
@@ -46,4 +46,27 @@ mod tests {
.to_data()
.assert_approx_eq(&Data::from([[30.5984, -47.2267], [55.9631, -56.5914]]), 3);
}

#[test]
fn test_quiet_softmax_grad() {
let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]);

let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());

let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3);
}
}
17 changes: 16 additions & 1 deletion burn-core/src/nn/attention/mha.rs
@@ -25,6 +25,14 @@ pub struct MultiHeadAttentionConfig {
/// A value too low might result in NaN.
#[config(default = -1.0e4)]
min_float: f64,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
@@ -51,6 +59,7 @@ pub struct MultiHeadAttention<B: Backend> {
n_heads: usize,
d_k: usize,
min_float: f64,
quiet_softmax: bool,
}

/// [Multihead attention](MultiHeadAttention) forward pass input argument.
@@ -82,6 +91,7 @@ impl MultiHeadAttentionConfig {
n_heads: self.n_heads,
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
}
}

@@ -105,6 +115,7 @@ impl MultiHeadAttentionConfig {
n_heads: self.n_heads,
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
}
}
}
@@ -249,7 +260,11 @@ impl<B: Backend> MultiHeadAttention<B> {
);
}

activation::softmax(attn_scores, 3)
if self.quiet_softmax {
activation::quiet_softmax(attn_scores, 3)
} else {
activation::softmax(attn_scores, 3)
}
}

fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
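Taken together, the new `quiet_softmax` field and the generated `with_quiet_softmax` builder make the feature opt-in per attention module. A minimal sketch of enabling it (not part of the diff; the backend type and the `d_model`/`n_heads` values are placeholders):

```rust
use burn::nn::attention::{MultiHeadAttention, MultiHeadAttentionConfig};
use burn::tensor::backend::Backend;

/// Attention module whose scores go through `quiet_softmax` instead of `softmax`.
fn quiet_attention<B: Backend>() -> MultiHeadAttention<B> {
    MultiHeadAttentionConfig::new(384, 6) // d_model = 384, n_heads = 6 (placeholder values)
        .with_quiet_softmax(true)
        .init()
}
```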
8 changes: 4 additions & 4 deletions burn-core/src/nn/loss/binary_cross_entropy.rs
@@ -45,10 +45,10 @@ impl BinaryCrossEntropyLossConfig {
fn assertions(&self) {
if let Some(alpha) = self.smoothing {
assert!(
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
alpha
);
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
alpha
);
};
if let Some(weights) = self.weights.as_ref() {
assert!(
8 changes: 4 additions & 4 deletions burn-core/src/nn/loss/cross_entropy.rs
@@ -53,10 +53,10 @@ impl CrossEntropyLossConfig {
fn assertions(&self) {
if let Some(alpha) = self.smoothing {
assert!(
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
alpha
);
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
alpha
);
};
if let Some(weights) = self.weights.as_ref() {
assert!(
12 changes: 12 additions & 0 deletions burn-core/src/nn/transformer/decoder.rs
@@ -34,6 +34,14 @@ pub struct TransformerDecoderConfig {
/// Layer norm will be applied first instead of after the other modules.
#[config(default = false)]
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
@@ -186,11 +194,13 @@ impl<B: Backend> TransformerDecoderLayer<B> {
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();

let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();
let norm_1 = LayerNormConfig::new(config.d_model).init();
let norm_2 = LayerNormConfig::new(config.d_model).init();
@@ -219,10 +229,12 @@ impl<B: Backend> TransformerDecoderLayer<B> {
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init_with(record.self_attn);
let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init_with(record.cross_attn);
let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1);
let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2);
10 changes: 10 additions & 0 deletions burn-core/src/nn/transformer/encoder.rs
@@ -34,6 +34,14 @@ pub struct TransformerEncoderConfig {
/// Layer norm will be applied first instead of after the other modules.
#[config(default = false)]
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
@@ -175,6 +183,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init_with(record.mha);
let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1);
let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2);
@@ -197,6 +206,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();
let norm_1 = LayerNormConfig::new(config.d_model).init();
let norm_2 = LayerNormConfig::new(config.d_model).init();
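As the decoder and encoder hunks above show, both configs expose the same flag and simply forward it to every inner `MultiHeadAttentionConfig`, so one setting switches all attention layers at once. A sketch for the encoder (not part of the diff; the `new(d_model, d_ff, n_heads, n_layers)` argument order and values are assumptions for illustration, only `with_quiet_softmax` comes from this PR):

```rust
use burn::nn::transformer::{TransformerEncoder, TransformerEncoderConfig};
use burn::tensor::backend::Backend;

/// Every attention layer inside the encoder will use quiet softmax.
fn quiet_encoder<B: Backend>() -> TransformerEncoder<B> {
    // Assumed argument order: d_model, d_ff, n_heads, n_layers.
    TransformerEncoderConfig::new(256, 1024, 8, 4)
        .with_quiet_softmax(true)
        .init()
}
```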
16 changes: 8 additions & 8 deletions burn-derive/src/config/analyzer_enum.rs
@@ -72,11 +72,11 @@ impl ConfigEnumAnalyzer {
fn gen_serialize_fn(&self) -> TokenStream {
let enum_name = self.serde_enum_ident();
let variants = self.data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let (variant_input, variant_output) = self.gen_variant_field(variant);
let variant_name = &variant.ident;
let (variant_input, variant_output) = self.gen_variant_field(variant);

quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output }
});
quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output }
});
Comment on lines +75 to +79
Member: That formatting seems odd to me, and somehow fmt can't update it 🤔

Contributor (author): What would you like me to do to resolve this issue and merge the PR? `cargo fmt --all` results in a bit-identical repository. Is the issue these two lines of whitespace? Should I modify them manually? You mention no changes should be made to burn-derive. I'd like to get this closed out; apologies for the silliness of these problems.

Member: You can just reset all the changes under the burn-derive directory to origin/main. :)

Contributor: It's due to the `quote!` macro (rust-lang/rustfmt#8). If you comment out `quote!`, format, and uncomment, it'll do the right thing.

Contributor (@AlexErrant, Nov 27, 2023): aaaaaactshully it's due to the default `max_width = 100`. If you add a `rustfmt.toml` containing `max_width = 110` and then format, it works. I'll open another PR about this tomorrow, since it causes changes elsewhere and deserves its own discussion.
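The suggestion in the last comment boils down to a one-line `rustfmt.toml` at the workspace root. A sketch (not part of this diff), given that rustfmt's default is `max_width = 100`:

```toml
# rustfmt.toml: raise the line-width limit so rustfmt reflows the quote! bodies
max_width = 110
```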

let name = &self.name;

quote! {
@@ -97,11 +97,11 @@ impl ConfigEnumAnalyzer {
fn gen_deserialize_fn(&self) -> TokenStream {
let enum_name = self.serde_enum_ident();
let variants = self.data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let (variant_input, variant_output) = self.gen_variant_field(variant);
let variant_name = &variant.ident;
let (variant_input, variant_output) = self.gen_variant_field(variant);

quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output }
});
quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output }
});
let name = &self.name;

quote! {
6 changes: 3 additions & 3 deletions burn-derive/src/record/codegen_struct.rs
@@ -23,9 +23,9 @@ impl RecordItemCodegen for StructRecordItemCodegen {
/// Field to be serialized.
pub #name: <#ty as burn::record::Record>::Item<S>,
});
bounds.extend(quote!{
<#ty as burn::record::Record>::Item<S>: serde::Serialize + serde::de::DeserializeOwned,
});
bounds.extend(quote! {
<#ty as burn::record::Record>::Item<S>: serde::Serialize + serde::de::DeserializeOwned,
});
}
let bound = bounds.to_string();

36 changes: 18 additions & 18 deletions burn-import/src/burn/node/binary.rs
@@ -167,25 +167,25 @@ mod tests {
use crate::burn::{ScalarKind, ScalarType, TensorType};

macro_rules! test_binary_operator_on_tensors {
($operator:ident) => {{
one_node_graph(
BinaryNode::$operator(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
Type::Tensor(TensorType::new_float("tensor3", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.$operator(tensor2);
($operator:ident) => {{
one_node_graph(
BinaryNode::$operator(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
Type::Tensor(TensorType::new_float("tensor3", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.$operator(tensor2);

tensor3
}
},
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);
}};
}
tensor3
}
},
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);
}};
}

macro_rules! test_binary_operator_on_tensor_and_scalar {
($operator:ident, $burn_operator:ident) => {{
5 changes: 4 additions & 1 deletion burn-ndarray/src/ops/base.rs
@@ -354,7 +354,10 @@

for i in 0..D - 1 {
if shape_tensor.dims[i] != shape_indices.dims[i] {
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indices.dims);
panic!(
"Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}",
shape_tensor.dims, shape_indices.dims
);
}
batch_size *= shape_indices.dims[i];
}
Expand Down
20 changes: 20 additions & 0 deletions burn-tensor/src/tensor/activation/base.rs
@@ -31,6 +31,26 @@ pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) ->
tensor.div(tensor_tmp)
}

/// Applies the "quiet softmax" function on the input tensor along the given dimension.
/// This function is similar to the softmax function, but it allows for "no selection", e.g.,
/// all outputs can tend to zero.
///
/// `softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`
///
/// # Notes
///
/// The dimension argument `dim` specifies the dimension along which the function will be computed.
/// It must in the range of `0` and `D-1`.
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
check!(TensorCheck::dim_ops::<D>("quiet_softmax", dim));

let tensor = tensor.clone() - tensor.detach().max_dim(dim);
let tensor = tensor.exp();
let tensor_tmp = tensor.clone().sum_dim(dim);

tensor.div(tensor_tmp + 1)
}

/// Applies the log softmax function on the input tensor along the given dimension.
///
/// `log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`
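A minimal usage sketch of the new activation, mirroring the tensor API used in the autodiff test above (not part of the diff; the backend type and the logits are placeholders). The key difference to observe: a softmax row always sums to 1, while a quiet softmax row sums to strictly less than 1 because of the extra `1` in the denominator, which is what lets an attention head contribute nothing.

```rust
use burn::tensor::{activation, backend::Backend, Data, Tensor};

/// Runs the same logits through both activations and prints the row sums:
/// softmax sums to 1.0, quiet_softmax to something strictly below 1.0.
fn compare_softmaxes<B: Backend>() {
    let logits = Tensor::<B, 2>::from_data(Data::from([[2.0, 1.0, 0.5]]));

    let standard = activation::softmax(logits.clone(), 1);
    let quiet = activation::quiet_softmax(logits, 1);

    println!("{:?}", standard.sum_dim(1).to_data()); // ~1.0
    println!("{:?}", quiet.sum_dim(1).to_data());    // < 1.0
}
```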