From 17ce603ade518c895daba2ba158cf3eecc8fce7b Mon Sep 17 00:00:00 2001
From: Will Brickner
Date: Thu, 24 Aug 2023 16:58:59 -0500
Subject: [PATCH] Added quiet softmax activation + config option for attention
 related modules

---
 burn-core/src/nn/attention/mha.rs         | 17 ++++++++++++++++-
 burn-core/src/nn/transformer/decoder.rs   | 10 ++++++++++
 burn-core/src/nn/transformer/encoder.rs   | 10 ++++++++++
 burn-tensor/src/tensor/activation/base.rs | 21 +++++++++++++++++++++
 4 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/burn-core/src/nn/attention/mha.rs b/burn-core/src/nn/attention/mha.rs
index b1a0b3aa9e..549ec889ed 100644
--- a/burn-core/src/nn/attention/mha.rs
+++ b/burn-core/src/nn/attention/mha.rs
@@ -24,6 +24,14 @@ pub struct MultiHeadAttentionConfig {
     /// A value too low might result in NaN.
     #[config(default = -1.0e4)]
     min_float: f64,
+    /// Use "quiet softmax" instead of regular softmax.
+    ///
+    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
+    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
+    ///
+    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
+    #[config(default = false)]
+    quiet_softmax: bool,
 }
 
 /// The multihead attention module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
@@ -45,6 +53,7 @@ pub struct MultiHeadAttention<B: Backend> {
     n_heads: usize,
     d_k: usize,
     min_float: f64,
+    quiet_softmax: bool,
 }
 
 /// [Multihead attention](MultiHeadAttention) forward pass input argument.
@@ -72,6 +81,7 @@ impl MultiHeadAttentionConfig {
             n_heads: self.n_heads,
             d_k: self.d_model / self.n_heads,
             min_float: self.min_float,
+            quiet_softmax: self.quiet_softmax,
         }
     }
 
@@ -95,6 +105,7 @@ impl MultiHeadAttentionConfig {
             n_heads: self.n_heads,
             d_k: self.d_model / self.n_heads,
             min_float: self.min_float,
+            quiet_softmax: self.quiet_softmax,
         }
     }
 }
@@ -239,7 +250,11 @@ impl<B: Backend> MultiHeadAttention<B> {
             );
         }
 
-        activation::softmax(attn_scores, 3)
+        if self.quiet_softmax {
+            activation::quiet_softmax(attn_scores, 3)
+        } else {
+            activation::softmax(attn_scores, 3)
+        }
     }
 
     fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
diff --git a/burn-core/src/nn/transformer/decoder.rs b/burn-core/src/nn/transformer/decoder.rs
index 432aa9dc91..8c750a6a3a 100644
--- a/burn-core/src/nn/transformer/decoder.rs
+++ b/burn-core/src/nn/transformer/decoder.rs
@@ -34,6 +34,14 @@ pub struct TransformerDecoderConfig {
     /// Layer norm will be applied first instead of after the other modules.
     #[config(default = false)]
     pub norm_first: bool,
+    /// Use "quiet softmax" instead of regular softmax.
+    ///
+    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
+    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
+    ///
+    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
+    #[config(default = false)]
+    pub quiet_softmax: bool,
 }
 
 /// The transformer decoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
@@ -180,10 +188,12 @@ impl<B: Backend> TransformerDecoderLayer<B> {
     fn new(config: &TransformerDecoderConfig) -> Self {
         let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
             .with_dropout(config.dropout)
+            .with_quiet_softmax(config.quiet_softmax)
             .init();
         let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
             .with_dropout(config.dropout)
+            .with_quiet_softmax(config.quiet_softmax)
             .init();
         let norm_1 = LayerNormConfig::new(config.d_model).init();
         let norm_2 = LayerNormConfig::new(config.d_model).init();
diff --git a/burn-core/src/nn/transformer/encoder.rs b/burn-core/src/nn/transformer/encoder.rs
index bdfed82d90..449f7ca888 100644
--- a/burn-core/src/nn/transformer/encoder.rs
+++ b/burn-core/src/nn/transformer/encoder.rs
@@ -34,6 +34,14 @@ pub struct TransformerEncoderConfig {
     /// Layer norm will be applied first instead of after the other modules.
     #[config(default = false)]
     pub norm_first: bool,
+    /// Use "quiet softmax" instead of regular softmax.
+    ///
+    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
+    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
+    ///
+    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
+    #[config(default = false)]
+    pub quiet_softmax: bool,
 }
 
 /// The transformer encoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
@@ -169,6 +177,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
     ) -> Self {
         let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
             .with_dropout(config.dropout)
+            .with_quiet_softmax(config.quiet_softmax)
             .init_with(record.mha);
         let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1);
         let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2);
@@ -189,6 +198,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
     fn new(config: &TransformerEncoderConfig) -> Self {
         let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
             .with_dropout(config.dropout)
+            .with_quiet_softmax(config.quiet_softmax)
             .init();
         let norm_1 = LayerNormConfig::new(config.d_model).init();
         let norm_2 = LayerNormConfig::new(config.d_model).init();
diff --git a/burn-tensor/src/tensor/activation/base.rs b/burn-tensor/src/tensor/activation/base.rs
index 78858cfb2e..7f95f79e85 100644
--- a/burn-tensor/src/tensor/activation/base.rs
+++ b/burn-tensor/src/tensor/activation/base.rs
@@ -31,6 +31,27 @@ pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) ->
     tensor.div(tensor_tmp)
 }
 
+/// Applies the "quiet softmax" function on the input tensor along the given dimension.
+/// This function is similar to the softmax function, but it allows for "no selection", i.e.,
+/// all outputs can tend to zero.
+///
+/// `quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`
+///
+/// # Notes
+///
+/// The dimension argument `dim` specifies the dimension along which the function will be computed.
+/// It must be in the range of `0` and `D-1`.
+pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
+    check!(TensorCheck::dim_ops::<D>("softmax", dim));
+
+    let tensor = tensor.clone() - tensor.detach().max_dim(dim);
+    let tensor = tensor.exp();
+    let tensor_tmp = tensor.clone().sum_dim(dim);
+    let one = Tensor::<B, D>::ones([1; D]).to_device(&tensor_tmp.device());
+
+    tensor.div(tensor_tmp + one)
+}
+
 /// Applies the log softmax function on the input tensor along the given dimension.
 ///
 /// `log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`
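Usage note (not part of the patch): downstream code reaches the new flag through the `with_quiet_softmax` builder that `#[config(default = false)]` derives, the same call the encoder and decoder layers make above. The sketch below only reuses calls that appear in the diff (`MultiHeadAttentionConfig::new`, `with_dropout`, `with_quiet_softmax`, `init`); the crate paths and the `d_model`/`n_heads`/dropout values are assumptions chosen for illustration.

```rust
// Minimal usage sketch (not part of the patch). Every builder call here is one
// exercised in the diff above; crate paths and numeric values are illustrative.
use burn_core::nn::attention::{MultiHeadAttention, MultiHeadAttentionConfig};
use burn_tensor::backend::Backend;

/// Builds a multi-head attention module with the new quiet-softmax path enabled.
fn build_attention<B: Backend>() -> MultiHeadAttention<B> {
    MultiHeadAttentionConfig::new(512, 8) // d_model = 512, n_heads = 8
        .with_dropout(0.1)
        .with_quiet_softmax(true) // opt in; the default stays `false`
        .init()
}
```

Because the default is `false`, existing configs and saved records keep the regular softmax behavior unless they explicitly opt in, which is why the encoder and decoder `init_with` paths can forward the flag without breaking older records.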
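The activation is also callable directly from `burn_tensor::activation`, which is what `MultiHeadAttention` now does on its `[batch, heads, seq_q, seq_k]` score tensor when the flag is set. Below is a hand-checkable comparison sketch; the `NdArrayBackend` type name and the device-less `zeros` constructor are assumptions about the Burn API of this vintage, and the commented numbers follow the implementation added above.

```rust
// Comparison sketch (not part of the patch). Backend choice is an assumption.
use burn_ndarray::NdArrayBackend;
use burn_tensor::{activation, Tensor};

type B = NdArrayBackend<f32>;

fn main() {
    // A tiny [1, 1, 2, 2] block of all-zero attention scores, so results are easy to check.
    let scores = Tensor::<B, 4>::zeros([1, 1, 2, 2]);

    // Regular softmax over the key dimension: each row becomes [0.5, 0.5] and sums to 1.
    let dense = activation::softmax(scores.clone(), 3);

    // Quiet softmax: exp(0) / (1 + 2 * exp(0)) = 1/3 per entry, so each row sums to 2/3.
    // The extra 1 in the denominator is what lets a head attend to (nearly) nothing.
    let quiet = activation::quiet_softmax(scores, 3);

    let _ = (dense, quiet);
}
```

The slack between the row sums (1 versus 2/3 here) is the "no selection" behavior the new doc comment describes: a head whose scores carry no relevant information can push its attention weights toward zero instead of being forced to distribute a full unit of probability mass.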