Added quiet softmax activation + config option for attention related modules
wbrickner committed Aug 24, 2023
1 parent 836fcc7 commit 17ce603
Showing 4 changed files with 57 additions and 1 deletion.
17 changes: 16 additions & 1 deletion burn-core/src/nn/attention/mha.rs
@@ -24,6 +24,14 @@ pub struct MultiHeadAttentionConfig {
/// A value too low might result in NaN.
#[config(default = -1.0e4)]
min_float: f64,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
quiet_softmax: bool,
}

/// The multihead attention module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
@@ -45,6 +53,7 @@ pub struct MultiHeadAttention<B: Backend> {
n_heads: usize,
d_k: usize,
min_float: f64,
quiet_softmax: bool,
}

/// [Multihead attention](MultiHeadAttention) forward pass input argument.
@@ -72,6 +81,7 @@ impl MultiHeadAttentionConfig {
n_heads: self.n_heads,
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
}
}

@@ -95,6 +105,7 @@ impl MultiHeadAttentionConfig {
n_heads: self.n_heads,
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
}
}
}
@@ -239,7 +250,11 @@ impl<B: Backend> MultiHeadAttention<B> {
);
}

activation::softmax(attn_scores, 3)
if self.quiet_softmax {
activation::quiet_softmax(attn_scores, 3)
} else {
activation::softmax(attn_scores, 3)
}
}

fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
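The new `quiet_softmax` flag shown above is wired through the generated builder setter. A minimal sketch of turning it on for a standalone attention module, using only calls that appear in this diff (`new`, `with_dropout`, `with_quiet_softmax`, `init`); the import path and the hyperparameter values are illustrative assumptions, not part of the commit:

```rust
use burn::nn::attention::{MultiHeadAttention, MultiHeadAttentionConfig};
use burn::tensor::backend::Backend;

/// Builds a multi-head attention module that uses quiet softmax for its
/// attention weights. Hyperparameter values are placeholders.
fn quiet_attention<B: Backend>() -> MultiHeadAttention<B> {
    MultiHeadAttentionConfig::new(256, 8) // d_model = 256, n_heads = 8
        .with_dropout(0.1)
        .with_quiet_softmax(true) // option added by this commit (defaults to false)
        .init()
}
```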
10 changes: 10 additions & 0 deletions burn-core/src/nn/transformer/decoder.rs
@@ -34,6 +34,14 @@ pub struct TransformerDecoderConfig {
/// Layer norm will be applied first instead of after the other modules.
#[config(default = false)]
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
}

/// The transformer decoder module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
@@ -180,10 +188,12 @@ impl<B: Backend> TransformerDecoderLayer<B> {
fn new(config: &TransformerDecoderConfig) -> Self {
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();

let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();
let norm_1 = LayerNormConfig::new(config.d_model).init();
let norm_2 = LayerNormConfig::new(config.d_model).init();
10 changes: 10 additions & 0 deletions burn-core/src/nn/transformer/encoder.rs
@@ -34,6 +34,14 @@ pub struct TransformerEncoderConfig {
/// Layer norm will be applied first instead of after the other modules.
#[config(default = false)]
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
}

/// The transformer encoder module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
@@ -169,6 +177,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
) -> Self {
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init_with(record.mha);
let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1);
let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2);
@@ -189,6 +198,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
fn new(config: &TransformerEncoderConfig) -> Self {
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();
let norm_1 = LayerNormConfig::new(config.d_model).init();
let norm_2 = LayerNormConfig::new(config.d_model).init();
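At the transformer level the flag only needs to be set once on the stack config; as the diffs above show, both `TransformerDecoderConfig` and `TransformerEncoderConfig` forward it to every `MultiHeadAttentionConfig` they build. A rough sketch for the encoder, assuming the usual `burn::nn::transformer` re-export and a constructor argument order of `(d_model, d_ff, n_heads, n_layers)`:

```rust
use burn::nn::transformer::{TransformerEncoder, TransformerEncoderConfig};
use burn::tensor::backend::Backend;

/// Every attention layer in this encoder stack will use quiet softmax,
/// since the flag is forwarded to each MultiHeadAttentionConfig.
fn quiet_encoder<B: Backend>() -> TransformerEncoder<B> {
    // Argument order and values below are assumptions for illustration.
    TransformerEncoderConfig::new(512, 2048, 8, 6)
        .with_quiet_softmax(true)
        .init()
}
```

The decoder is configured the same way via `TransformerDecoderConfig::with_quiet_softmax`.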
21 changes: 21 additions & 0 deletions burn-tensor/src/tensor/activation/base.rs
@@ -31,6 +31,27 @@ pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) ->
tensor.div(tensor_tmp)
}

/// Applies the "quiet softmax" function on the input tensor along the given dimension.
/// This function is similar to the softmax function, but it allows for "no selection", e.g.,
/// all outputs can tend to zero.
///
/// `quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`
///
/// # Notes
///
/// The dimension argument `dim` specifies the dimension along which the function will be computed.
/// It must be in the range `0` to `D-1`.
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
check!(TensorCheck::dim_ops::<D>("quiet_softmax", dim));

let tensor = tensor.clone() - tensor.detach().max_dim(dim);
let tensor = tensor.exp();
let tensor_tmp = tensor.clone().sum_dim(dim);
let one = Tensor::<B, D>::ones([1; D]).to_device(&tensor_tmp.device());

tensor.div(tensor_tmp + one)
}

/// Applies the log softmax function on the input tensor along the given dimension.
///
/// `log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`
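To make the "no selection" behaviour concrete, here is a plain-Rust scalar sketch of the documented formula (it intentionally omits the max-subtraction stabilization that the tensor implementation above performs):

```rust
/// Scalar version of quiet_softmax(x_i) = exp(x_i) / (1 + sum_j exp(x_j)).
fn quiet_softmax_scalar(xs: &[f64]) -> Vec<f64> {
    let denom = 1.0 + xs.iter().map(|x| x.exp()).sum::<f64>();
    xs.iter().map(|x| x.exp() / denom).collect()
}

fn main() {
    // With strongly negative inputs every output tends toward zero
    // (roughly 4.5e-5 each here), whereas a regular softmax would still
    // be forced to distribute a full unit of mass (1/3 each).
    println!("{:?}", quiet_softmax_scalar(&[-10.0, -10.0, -10.0]));
}
```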
