Adding nn.Embedding layer. #406

Merged · 4 commits · Jan 30, 2023
247 changes: 247 additions & 0 deletions src/nn/embedding.rs
@@ -0,0 +1,247 @@
use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*};

use super::module::{BuildModule, Module, ModuleMut, ResetParams, ToDevice};

/// An embedding layer that maps integer ids to rows of [Self::weight].
/// Initializes [Self::weight] from a Uniform distribution
/// between [-1 / sqrt(VOCAB), 1 / sqrt(VOCAB)].
///
/// # Generics
/// - `VOCAB` The size of the vocabulary; input integer values must be in
/// the range `0..VOCAB`.
/// - `DIM` The size of each embedding vector, i.e. the length of the rows
/// of [Self::weight] that are selected.
///
/// # Examples
/// `Embedding<5, 2>` can act on a vector of `SEQ` integer ids (each with a
/// value between 0 and 4), and produces a `(SEQ, 2)` tensor of f32 values
/// whose rows are the corresponding rows of [Self::weight].
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let mut model: Embedding<7, 2> = BuildModule::build(&dev);
/// // single sequence of ids
/// let inputs: Tensor<_, usize> = dev.zeros::<Rank1<5>>();
/// let _: Tensor<(Const<5>, Const<2>), f32> = model.forward(inputs);
/// // batched sequence of ids
/// let inputs: Tensor<_, usize> = dev.zeros::<Rank2<10, 5>>();
/// let _: Tensor<(Const<10>, Const<5>, Const<2>), f32> = model.forward(inputs);
/// ```
#[derive(Debug, Clone)]
pub struct Embedding<const VOCAB: usize, const DIM: usize, D: Device<f32> = Cpu> {
/// Weight matrix, shape (VOCAB, DIM); row `i` is the embedding vector for id `i`.
pub weight: Tensor<Rank2<VOCAB, DIM>, f32, D>,
}

impl<const VOCAB: usize, const DIM: usize, const SEQ: usize, D: Device<f32>, T: Tape<D>>
Module<Tensor<Rank1<SEQ>, usize, D, T>> for Embedding<VOCAB, DIM, D>
{
type Output = Tensor<Rank2<SEQ, DIM>, f32, D, T>;
fn forward(&self, input: Tensor<Rank1<SEQ>, usize, D, T>) -> Self::Output {
let (input, tape) = input.split_tape();
self.weight.clone().put_tape(tape).gather(input)
}
}
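
// The forward above is a pure row lookup: with weight of shape (VOCAB, DIM)
// and ids of shape (SEQ,), `gather` yields output[s] == weight[ids[s]].
// Cloning `self.weight` only copies a cheap reference-counted handle to the
// data, and `split_tape`/`put_tape` move the gradient tape off the integer
// ids (which carry no gradient) onto the weight so that backprop reaches it.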

impl<
const VOCAB: usize,
const DIM: usize,
const SEQ: usize,
const BATCH: usize,
D: Device<f32>,
T: Tape<D>,
> Module<Tensor<Rank2<BATCH, SEQ>, usize, D, T>> for Embedding<VOCAB, DIM, D>
{
type Output = Tensor<Rank3<BATCH, SEQ, DIM>, f32, D, T>;
fn forward(&self, input: Tensor<Rank2<BATCH, SEQ>, usize, D, T>) -> Self::Output {
let (input, tape) = input.split_tape();
self.weight.clone().put_tape(tape).gather(input)
}
}
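
// The batched forward is the same lookup applied per batch element:
// output[b][s] == weight[ids[b][s]].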

impl<T, const VOCAB: usize, const DIM: usize, D: Device<f32>> ModuleMut<T>
for Embedding<VOCAB, DIM, D>
where
Self: Module<T>,
{
type Output = <Self as Module<T>>::Output;
fn forward_mut(&mut self, input: T) -> Self::Output {
self.forward(input)
}
}
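
// Embedding has no train-mode-only behavior (unlike e.g. Dropout), so the
// mutable forward simply delegates to the immutable one.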

impl<const VOCAB: usize, const DIM: usize, D: Device<f32>> GradientUpdate<D, f32>
for Embedding<VOCAB, DIM, D>
{
fn update<U>(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), D::Err>
where
U: ParamUpdater<D, f32>,
{
self.weight.update(updater, unused)?;
Ok(())
}
}

impl<const VOCAB: usize, const DIM: usize, D: Device<f32>> ResetParams<D, f32>
for Embedding<VOCAB, DIM, D>
{
fn try_reset_params(&mut self) -> Result<(), D::Err> {
let bound: f32 = 1.0 / (VOCAB as f32).sqrt();
let distr = rand_distr::Uniform::new(-bound, bound);
self.weight.try_fill_with_distr(distr)?;
Ok(())
}
}

impl<const VOCAB: usize, const DIM: usize, D: Device<f32>> BuildModule<D, f32>
for Embedding<VOCAB, DIM, D>
{
fn try_build(device: &D) -> Result<Self, D::Err> {
let bound: f32 = 1.0 / (VOCAB as f32).sqrt();
let distr = rand_distr::Uniform::new(-bound, bound);
let weight = device.try_sample(distr)?;
Ok(Self { weight })
}
}
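
// `try_build` and `try_reset_params` draw from the same
// Uniform(-1/sqrt(VOCAB), 1/sqrt(VOCAB)) distribution; for VOCAB = 2000 (as
// in `test_embedding_initialize` below) the bound is
// 1.0 / 2000f32.sqrt() ≈ 0.02236.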

impl<const VOCAB: usize, const DIM: usize, D1: Device<f32>, D2: Device<f32>> ToDevice<D2>
for Embedding<VOCAB, DIM, D1>
{
type Output = Embedding<VOCAB, DIM, D2>;
fn to_device(&self, device: &D2) -> Self::Output {
Embedding {
weight: self.weight.to_device(device),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
nn::{tests::SimpleUpdater, BuildModule},
tests::{assert_close, TestDevice},
unique_id::HasUniqueId,
};

const W: [[f32; 5]; 2] = [
[-0.3458893, -0.30371523, -0.3712057, 0.14303583, -0.0268966],
[0.11733949, 0.14059687, -0.10670426, -0.09373143, 0.18974298],
];

#[test]
fn test_embedding_initialize() {
let dev: TestDevice = Default::default();
let m: Embedding<2000, 1> = BuildModule::build(&dev);
let bound = 1.0 / 2000.0f32.sqrt();
for v in m.weight.as_vec() {
assert!(-bound <= v && v <= bound && v != 0.0);
}
}

#[test]
fn test_embedding_forward_1d() {
let dev: TestDevice = Default::default();

let model = Embedding {
weight: dev.tensor(W),
};

let x = dev.tensor([0, 0, 1]);
let y = model.forward(x.trace());
assert_close(
&y.array(),
&[
[-0.3458893, -0.30371523, -0.3712057, 0.14303583, -0.0268966],
[-0.3458893, -0.30371523, -0.3712057, 0.14303583, -0.0268966],
[0.11733949, 0.14059687, -0.10670426, -0.09373143, 0.18974298],
],
);

let g = y.square().mean().backward();
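        // With L = mean(y^2) over 3 * 5 = 15 elements, each selection of a
        // row contributes 2 * W[row] / 15 to that row's gradient: row 0 is
        // selected twice (4 * W[0] / 15), row 1 once (2 * W[1] / 15).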
assert_close(
&g.get(&model.weight).array(),
&[
[
-0.09223715,
-0.08099073,
-0.09898819,
0.03814289,
-0.007172427,
],
[
0.015645266,
0.01874625,
-0.014227235,
-0.012497525,
0.025299065,
],
],
);
}

#[test]
fn test_embedding_forward_2d() {
let dev: TestDevice = Default::default();

let model = Embedding {
weight: dev.tensor(W),
};

let x = dev.tensor([[0, 0], [0, 1]]);
let y = model.forward(x.trace());
assert_close(
&y.array(),
&[
[
[-0.3458893, -0.30371523, -0.3712057, 0.14303583, -0.0268966],
[-0.3458893, -0.30371523, -0.3712057, 0.14303583, -0.0268966],
],
[
[-0.3458893, -0.30371523, -0.3712057, 0.14303583, -0.0268966],
[0.11733949, 0.14059687, -0.10670426, -0.09373143, 0.18974298],
],
],
);

let g = y.square().mean().backward();
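        // Here L = mean(y^2) over 2 * 2 * 5 = 20 elements; row 0 is selected
        // three times (6 * W[0] / 20), row 1 once (2 * W[1] / 20).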
assert_close(
&g.get(&model.weight).array(),
&[
[
-0.103766784,
-0.091114566,
-0.11136171,
0.042910747,
-0.008068981,
],
[
0.011733949,
0.014059687,
-0.010670426,
-0.009373143,
0.018974299,
],
],
);
}

#[test]
fn test_embedding_missing_gradients() {
let dev: TestDevice = Default::default();

let mut model: Embedding<5, 3> = BuildModule::build(&dev);
let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
model.update(&mut g, &mut unused).unwrap();
assert_eq!(&unused.ids, &[*model.weight.id()]);

g.0.try_alloc_for(&model.weight).unwrap();

// weight gradient is present
let mut unused = Default::default();
model.update(&mut g, &mut unused).unwrap();
assert!(unused.is_empty());
}
}
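
To make the forward and backward behavior concrete outside of dfdx, here is a
minimal, dependency-free sketch of what an embedding lookup and its gradient
compute; the helper names `embed` and `embed_grad` are illustrative only and
not part of the dfdx API:

```rust
/// Forward: out[s] = weight[ids[s]] for each position s.
fn embed(weight: &[Vec<f32>], ids: &[usize]) -> Vec<Vec<f32>> {
    ids.iter().map(|&id| weight[id].clone()).collect()
}

/// Backward: a scatter-add — each row of the output gradient is accumulated
/// into the weight row it was read from, so a row selected twice receives
/// two contributions.
fn embed_grad(vocab: usize, dim: usize, ids: &[usize], grad_out: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let mut grad_w = vec![vec![0.0f32; dim]; vocab];
    for (row, &id) in grad_out.iter().zip(ids) {
        for (g, w) in row.iter().zip(grad_w[id].iter_mut()) {
            *w += *g;
        }
    }
    grad_w
}

fn main() {
    let weight = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let ids = [0, 0, 1];
    let out = embed(&weight, &ids);
    assert_eq!(out[0], vec![1.0, 2.0]);
    // With L = mean(out^2) over 3 * 2 = 6 elements, dL/dout = 2 * out / 6.
    let grad_out: Vec<Vec<f32>> = out
        .iter()
        .map(|r| r.iter().map(|v| 2.0 * v / 6.0).collect())
        .collect();
    let grad_w = embed_grad(2, 2, &ids, &grad_out);
    // Row 0 was selected twice: 2 * (2 * weight[0] / 6) = 2 * weight[0] / 3.
    assert!((grad_w[0][0] - 2.0 / 3.0).abs() < 1e-6);
    assert!((grad_w[1][0] - 1.0).abs() < 1e-6);
}
```

The scatter-add in `embed_grad` is why, in the tests above, a row selected
multiple times receives a multiple of the per-selection gradient.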
2 changes: 2 additions & 0 deletions src/nn/mod.rs
@@ -106,6 +106,7 @@ mod add_into;
mod batchnorm2d;
mod conv;
mod dropout;
mod embedding;
mod flatten;
mod generalized_residual;
mod impl_module_for_tuples;
@@ -123,6 +124,7 @@ pub use activations::*;
pub use add_into::*;
pub use batchnorm2d::*;
pub use dropout::*;
pub use embedding::*;
pub use generalized_residual::*;
pub use impl_module_for_tuples::*;
pub use layer_norm::*;