diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 9617d1aca..b1f568178 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -482,3 +482,171 @@ mod tests {
         assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
     }
 }
+
+pub trait ShapeWithOneHole {
+    fn into_shape(self, el_count: usize) -> Result<Shape>;
+}
+
+impl<S: Into<Shape>> ShapeWithOneHole for S {
+    fn into_shape(self, _el_count: usize) -> Result<Shape> {
+        Ok(self.into())
+    }
+}
+
+impl ShapeWithOneHole for ((),) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        Ok(el_count.into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1) = self;
+        if el_count % d1 != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+        }
+        Ok((el_count / d1, d1).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, ()) = self;
+        if el_count % d1 != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+        }
+        Ok((d1, el_count / d1).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, ()) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2, d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2, d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, (), d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, ()) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, el_count / d).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2, d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2, d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, (), d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, (), d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, el_count / d, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, d4, ()) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, d4, el_count / d).into())
+    }
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 1eca694c7..6bb3d7409 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1685,12 +1685,15 @@ impl Tensor {
         Ok(from_storage(storage, shape, BackpropOp::none(), true))
     }
 
-    // TODO: Do we want to allow target shape using -1 on some dimensions?
     /// Reshape returns a tensor with the target shape provided that the number of elements of the
     /// original tensor is the same.
     /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
     /// a new storage and copies the data over, the returned tensor is always contiguous.
     ///
+    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
+    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
+    /// as to match the number of elements in the tensor.
+    ///
     /// ```rust
     /// # use candle_core::{Tensor, DType, Device, D};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@@ -1700,10 +1703,14 @@ impl Tensor {
     ///
     /// let c = a.reshape((3, 2))?;
     /// assert_eq!(c.shape().dims(), &[3, 2]);
+    ///
+    /// let c = a.reshape((2, (), 1))?;
+    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
+    ///
     /// # Ok::<(), candle_core::Error>(())
     /// ```
-    pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
-        let shape = shape.into();
+    pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
+        let shape = s.into_shape(self.elem_count())?;
         if shape.elem_count() != self.elem_count() {
             return Err(Error::ShapeMismatchBinaryOp {
                 lhs: self.shape().clone(),
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 4cc7e5fb0..64275fda2 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -57,10 +57,13 @@ fn main() -> Result<()> {
             #[allow(clippy::redundant_clone)]
             out_dir.clone()
         }
-        Ok(build_dir) =>
-        {
+        Ok(build_dir) => {
             let path = PathBuf::from(build_dir);
-            path.canonicalize().expect(&format!("Directory doesn't exists: {} (the current directory is {})", &path.display(), std::env::current_dir()?.display()))
+            path.canonicalize().expect(&format!(
+                "Directory doesn't exists: {} (the current directory is {})",
+                &path.display(),
+                std::env::current_dir()?.display()
+            ))
         }
     };
     set_cuda_include_dir()?;
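For anyone who wants to try the hole-based reshape outside the doc-test, here is a minimal standalone sketch. It only relies on `candle_core` items already shown in the doc comment above (`Tensor::zeros`, `reshape`, `shape().dims()`, the crate-level `Result`); the concrete shapes are arbitrary choices for illustration and are not part of the patch.

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // 2 * 3 * 4 = 24 elements; the shapes below are illustrative only.
    let a = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;

    // A single `()` behaves like PyTorch's -1: its size is inferred from the
    // element count, so (6, ()) resolves to (6, 4).
    let b = a.reshape((6, ()))?;
    assert_eq!(b.shape().dims(), &[6, 4]);

    // Plain usize tuples still go through the blanket `Into<Shape>` impl.
    let c = a.reshape((4, 6))?;
    assert_eq!(c.shape().dims(), &[4, 6]);

    // When the known dimensions do not divide the element count, reshape errors out.
    assert!(a.reshape((5, ())).is_err());

    Ok(())
}
```

Since `((),)` and the other hole tuples do not implement `Into<Shape>`, the blanket impl and the hole impls cannot overlap, which is what lets `ShapeWithOneHole` replace the old `Into<Shape>` bound on `reshape` without breaking existing callers.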