Primary language: Rust. License: Apache License 2.0 (Apache-2.0).

Candle Extensions

An extension library to Candle that provides PyTorch functions not currently available in Candle

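 use candle_ext::{
     candle::{D, DType, Device, Result, Tensor},
     TensorExt, F,
 };

 fn main() -> Result<()> {
     let device = Device::Cpu;
     // Query, key and value tensors; the second-to-last dim is the sequence length.
     let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
     let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
     let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
     // Lower-triangular mask of shape (query length, key length); `tril` comes from TensorExt.
     let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;

     // Pass the mask as the optional fourth argument; the remaining options are left as None.
     let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, None)?;
     println!("{:?}", o.shape());

     Ok(())
 }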
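
To make clearer what the example above computes, here is a minimal plain-Rust sketch of masked scaled dot-product attention for a single head, following the PyTorch definition softmax(Q Kᵀ / sqrt(d)) V. It does not use Candle and is not this crate's implementation. The names `sdpa_reference` and `causal_mask` are hypothetical helpers introduced for illustration. The sketch also assumes a boolean mask in which `true` means the query may attend to the key.

```rust
/// Reference sketch (not the candle-ext implementation) of
/// softmax(Q K^T / sqrt(d)) V for one attention head.
/// `q` is (n_q, d), `k` is (n_k, d), `v` is (n_k, d_v).
/// `mask[i][j] == true` means query i may attend to key j (assumed convention).
/// Each query row must have at least one allowed key.
pub fn sdpa_reference(
    q: &[Vec<f64>],
    k: &[Vec<f64>],
    v: &[Vec<f64>],
    mask: Option<&[Vec<bool>]>,
) -> Vec<Vec<f64>> {
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    let d_v = v[0].len();
    let mut out = Vec::with_capacity(q.len());
    for (i, qi) in q.iter().enumerate() {
        // Scaled dot products; masked-out positions get -inf so softmax gives them weight 0.
        let scores: Vec<f64> = k
            .iter()
            .enumerate()
            .map(|(j, kj)| {
                let allowed = mask.map_or(true, |m| m[i][j]);
                if allowed {
                    qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect();
        // Numerically stable softmax: subtract the row maximum before exponentiating.
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
        let denom: f64 = exps.iter().sum();
        // Output row is the attention-weighted sum of the value rows.
        let mut row = vec![0.0; d_v];
        for (w, vj) in exps.iter().zip(v) {
            for (r, x) in row.iter_mut().zip(vj) {
                *r += w / denom * x;
            }
        }
        out.push(row);
    }
    out
}

/// Lower-triangular (causal) mask: query i may attend to keys 0..=i.
pub fn causal_mask(n_q: usize, n_k: usize) -> Vec<Vec<bool>> {
    (0..n_q).map(|i| (0..n_k).map(|j| j <= i).collect()).collect()
}
```

Calling `sdpa_reference(&q, &k, &v, Some(causal_mask(n_q, n_k).as_slice()))` mirrors the role of the `tril(0)` mask in the example: the first query only sees the first key, the second query sees the first two, and so on.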

Currently provides (see also tests):

  • F::scaled_dot_product_attention

  • F::chunk2..5 / Tensor::chunk2..5

  • F::cumsum / Tensor::cumsum

  • F::equal / Tensor::equal

  • F::eye / Tensor::eye

  • F::full / Tensor::full

  • F::full_like / Tensor::full_like

  • F::scatter / Tensor::scatter

  • F::triu / Tensor::triu

  • F::tril / Tensor::tril

  • F::masked_fill / Tensor::masked_fill

  • F::logical_not / Tensor::logical_not

  • F::logical_or / Tensor::logical_or

  • F::outer / Tensor::outer

  • F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5
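
For readers unfamiliar with the PyTorch functions these names come from, the following plain-Rust sketch illustrates the semantics of four of them (`tril`, `triu`, `cumsum`, `outer`) on nested `Vec`s. It assumes the crate mirrors the PyTorch behavior. It does not use Candle tensors, and the signatures here are illustrative rather than the crate's actual API.

```rust
/// Like torch.tril: keep elements on and below the `diagonal`-th diagonal, zero the rest.
pub fn tril(m: &[Vec<f64>], diagonal: i64) -> Vec<Vec<f64>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if (j as i64) - (i as i64) <= diagonal { x } else { 0.0 })
                .collect()
        })
        .collect()
}

/// Like torch.triu: keep elements on and above the `diagonal`-th diagonal, zero the rest.
pub fn triu(m: &[Vec<f64>], diagonal: i64) -> Vec<Vec<f64>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if (j as i64) - (i as i64) >= diagonal { x } else { 0.0 })
                .collect()
        })
        .collect()
}

/// Like torch.cumsum on a 1-D input: running sum of the elements.
pub fn cumsum(xs: &[f64]) -> Vec<f64> {
    let mut acc = 0.0;
    xs.iter()
        .map(|x| {
            acc += x;
            acc
        })
        .collect()
}

/// Like torch.outer: result[i][j] = a[i] * b[j].
pub fn outer(a: &[f64], b: &[f64]) -> Vec<Vec<f64>> {
    a.iter()
        .map(|x| b.iter().map(|y| x * y).collect())
        .collect()
}
```

For example, `tril` with `diagonal = 0` on a 3x3 matrix keeps the main diagonal and everything below it, which is the shape of the causal mask built in the attention example.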

License

Licensed under either of

at your option.

Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.