An extension library for Candle that provides PyTorch functions not currently available in Candle.
use candle_ext::{
candle::{ D, DType, Device, Result, Tensor},
TensorExt, F,
};
/// Example: masked scaled dot-product attention on random CPU tensors.
///
/// Any error raised while building the tensors or running the attention
/// call is propagated to the caller via `?`.
fn main() -> Result<()> {
    let device = Device::Cpu;

    // Random query / key / value tensors (mean 0, std 1).
    let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
    let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
    let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;

    // The mask is sized by the second-to-last dimension of `q` and of `k`.
    let q_len = q.dim(D::Minus2)?;
    let k_len = k.dim(D::Minus2)?;

    // All-ones U8 matrix passed through `tril(0)` (from `TensorExt`).
    let all_ones = Tensor::ones((q_len, k_len), DType::U8, &device)?;
    let mask = all_ones.tril(0)?;

    // The remaining optional arguments are left unset.
    // The result is not used further in this example.
    let _output = F::scaled_dot_product_attention(&q, &k, &v, Some(&mask), None, None, None)?;

    Ok(())
}
Currently provides (see also tests):
- F::scaled_dot_product_attention
- F::chunk2..5 / Tensor::chunk2..5
- F::cumsum / Tensor::cumsum
- F::equal / Tensor::equal
- F::eye / Tensor::eye
- F::full / Tensor::full
- F::full_like / Tensor::full_like
- F::scatter / Tensor::scatter
- F::triu / Tensor::triu
- F::tril / Tensor::tril
- F::masked_fill / Tensor::masked_fill
- F::logical_not / Tensor::logical_not
- F::logical_or / Tensor::logical_or
- F::outer / Tensor::outer
- F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5
Licensed under either of
- Apache License, Version 2.0, (LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or https://opensource.org/licenses/MIT)
at your option.
Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.