
An extension library to Candle that provides PyTorch functions not currently available in Candle

License: Apache-2.0 (LICENSE-APACHE) or MIT (LICENSE-MIT)

mokeyish/candle-ext


Candle Extensions


```rust
use candle_ext::{
    candle::{DType, Device, Result, Tensor, D},
    TensorExt, F,
};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Query, key and value tensors sampled from N(0, 1).
    let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
    let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
    let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
    // Lower-triangular (causal) mask of shape (q.dim(-2), k.dim(-2)); `tril` comes from TensorExt.
    let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;

    let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, None)?;
    println!("{:?}", o.shape());

    Ok(())
}
```
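The call above follows PyTorch's `scaled_dot_product_attention`, which computes softmax(QKᵀ/√d)V with masked-out scores set to −∞ before the softmax. As a dependency-free sketch of those semantics for a single head (plain 2-D inputs, no batching or broadcasting), assuming the PyTorch convention that a nonzero mask entry means "may attend" and a default scale of 1/√d, the computation looks roughly like this. The function `sdpa` and the toy data are hypothetical illustrations, not candle-ext's implementation or API.

```rust
/// Hypothetical single-head reference for scaled dot-product attention.
/// Assumption: nonzero mask entries mark positions that may be attended to
/// (PyTorch's boolean-mask convention); scale defaults to 1/sqrt(d).
fn sdpa(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>], mask: Option<&[Vec<u8>]>) -> Vec<Vec<f64>> {
    let d = q[0].len() as f64;
    let scale = 1.0 / d.sqrt();
    q.iter()
        .enumerate()
        .map(|(i, qi)| {
            // score_ij = (q_i . k_j) * scale, or -inf where the mask forbids attention.
            let scores: Vec<f64> = k
                .iter()
                .enumerate()
                .map(|(j, kj)| {
                    let allowed = mask.map_or(true, |m| m[i][j] != 0);
                    if allowed {
                        qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale
                    } else {
                        f64::NEG_INFINITY
                    }
                })
                .collect();
            // Numerically stable softmax over j (a fully masked row would yield NaN).
            let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let sum: f64 = exps.iter().sum();
            // out_i = sum_j softmax_ij * v_j
            let mut out = vec![0.0; v[0].len()];
            for (w, vj) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w / sum * x;
                }
            }
            out
        })
        .collect()
}

fn main() {
    // Toy data: 2 queries, 3 keys/values, head dimension 2.
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let k = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let v = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    // Causal mask, analogous to `ones(..).tril(0)` in the example above.
    let mask = vec![vec![1u8, 0, 0], vec![1, 1, 0]];
    let out = sdpa(&q, &k, &v, Some(mask.as_slice()));
    println!("{:?}", out);
}
```

With this mask the first query can only see the first key, so its output is exactly the first value row; the second query mixes the first two value rows.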

Functions currently provided (see also the tests):

  • F::scaled_dot_product_attention

  • F::chunk2..5 / Tensor::chunk2..5

  • F::cumsum / Tensor::cumsum

  • F::equal / Tensor::equal

  • F::eye / Tensor::eye

  • F::full / Tensor::full

  • F::full_like / Tensor::full_like

  • F::scatter / Tensor::scatter

  • F::triu / Tensor::triu

  • F::tril / Tensor::tril

  • F::masked_fill / Tensor::masked_fill

  • F::logical_not / Tensor::logical_not

  • F::logical_or / Tensor::logical_or

  • F::outer / Tensor::outer

  • F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5
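The names above mirror PyTorch functions. As a rough, dependency-free illustration of what a few of them compute, here is a plain-Rust sketch on row-major `Vec<Vec<f64>>` matrices following PyTorch's documented semantics for `tril`, `triu`, `cumsum` (along the last dimension), `masked_fill` and `outer`. The helper names are hypothetical; candle-ext itself operates on candle `Tensor`s, and its exact signatures may differ.

```rust
// Hypothetical helpers sketching PyTorch semantics; not candle-ext's code.

/// Like torch.tril: keep entries with j - i <= diagonal, zero the rest.
fn tril(m: &[Vec<f64>], diagonal: isize) -> Vec<Vec<f64>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if (j as isize) - (i as isize) <= diagonal { x } else { 0.0 })
                .collect()
        })
        .collect()
}

/// Like torch.triu: keep entries with j - i >= diagonal, zero the rest.
fn triu(m: &[Vec<f64>], diagonal: isize) -> Vec<Vec<f64>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if (j as isize) - (i as isize) >= diagonal { x } else { 0.0 })
                .collect()
        })
        .collect()
}

/// Like torch.cumsum(m, dim=-1): running sum along each row.
fn cumsum_rows(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    m.iter()
        .map(|row| {
            let mut acc = 0.0;
            row.iter()
                .map(|&x| {
                    acc += x;
                    acc
                })
                .collect()
        })
        .collect()
}

/// Like Tensor.masked_fill: replace entries where the mask is true with `value`.
fn masked_fill(m: &[Vec<f64>], mask: &[Vec<bool>], value: f64) -> Vec<Vec<f64>> {
    m.iter()
        .zip(mask)
        .map(|(row, mrow)| row.iter().zip(mrow).map(|(&x, &b)| if b { value } else { x }).collect())
        .collect()
}

/// Like torch.outer: out[i][j] = a[i] * b[j].
fn outer(a: &[f64], b: &[f64]) -> Vec<Vec<f64>> {
    a.iter().map(|&x| b.iter().map(|&y| x * y).collect()).collect()
}

fn main() {
    let m = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], vec![7.0, 8.0, 9.0]];
    println!("tril = {:?}", tril(&m, 0));
    println!("triu = {:?}", triu(&m, 1));
    println!("cumsum = {:?}", cumsum_rows(&m));
    // Typical attention use: hide entries above the diagonal with -inf.
    let upper: Vec<Vec<bool>> = (0..3).map(|i| (0..3).map(|j| j > i).collect()).collect();
    println!("masked = {:?}", masked_fill(&m, &upper, f64::NEG_INFINITY));
    println!("outer = {:?}", outer(&[1.0, 2.0], &[3.0, 4.0]));
}
```

The `diagonal` offset works as in PyTorch: `0` is the main diagonal, positive values move above it and negative values below it.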

License

Licensed under either of

  • Apache License, Version 2.0 (LICENSE-APACHE)

  • MIT license (LICENSE-MIT)

at your option.

Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.
