Various small fixes #814

Merged
merged 5 commits on Jul 14, 2023
Changes from 4 commits
5 changes: 1 addition & 4 deletions examples/safetensors-save-load.rs
@@ -3,10 +3,7 @@
#[cfg(feature = "safetensors")]
fn main() {
use ::safetensors::SafeTensors;
-use dfdx::{
-prelude::*,
-tensor::{AsArray, AutoDevice, Cpu},
-};
+use dfdx::prelude::*;
use memmap2::MmapOptions;
let dev: Cpu = Default::default();

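For context on the simplified import (this sketch is not part of the PR): once `dfdx::prelude::*` is in scope it evidently re-exports `Cpu`, `AutoDevice`, and `AsArray`, which is why the nested `tensor::{...}` path could be dropped. A minimal example, assuming dfdx's usual `dev.tensor(...)` constructor and `.array()` accessor:

```rust
use dfdx::prelude::*;

fn main() {
    // `Cpu` and the `AsArray` trait both arrive via the prelude.
    let dev: Cpu = Default::default();
    let t = dev.tensor([1.0f32, 2.0, 3.0]);
    let _host: [f32; 3] = t.array();
}
```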
7 changes: 5 additions & 2 deletions src/shapes/shape.rs
@@ -1,5 +1,8 @@
use super::{axes::*, ReduceShape, ReduceShapeTo};

+#[cfg(feature = "f16")]
+pub use half::f16;
+
#[cfg(not(feature = "cuda"))]
pub trait SafeZeros {}

@@ -48,7 +51,7 @@ unit!(u128, 1);
unit!(i128, 1);
unit!(bool, true);
#[cfg(feature = "f16")]
-unit!(half::f16, half::f16::ONE);
+unit!(f16, f16::ONE);

/// Represents something that has a [Unit].
pub trait HasUnitType {
@@ -88,7 +91,7 @@ impl Dtype for u64 {}
impl Dtype for u128 {}
impl Dtype for usize {}
#[cfg(feature = "f16")]
-impl Dtype for half::f16 {}
+impl Dtype for f16 {}

/// Represents something that has a [Dtype].
pub trait HasDtype {
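The new `pub use half::f16;` means downstream code can name `f16` through dfdx instead of adding its own dependency on the `half` crate. A hypothetical usage sketch; the `dfdx::shapes::f16` path assumes the `shapes` module re-exports the items in `shape.rs`, and the `f16` cargo feature must be enabled:

```rust
use dfdx::shapes::f16; // re-export added by this PR (path is an assumption)

fn one() -> f16 {
    // Same constant the diff now writes without the `half::` prefix.
    f16::ONE
}

fn main() {
    assert_eq!(one(), f16::ONE);
}
```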
2 changes: 1 addition & 1 deletion src/tensor/cpu/device.rs
@@ -220,7 +220,7 @@ impl Cache for Cpu {
debug_assert_eq!(std::alloc::Layout::new::<u32>().align(), 4);
debug_assert_eq!(std::alloc::Layout::new::<u64>().align(), 8);
match key.alignment {
-1 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u8, len, cap)) },
+1 => unsafe { drop(Vec::from_raw_parts(alloc.0, len, cap)) },
2 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u16, len, cap)) },
4 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u32, len, cap)) },
8 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u64, len, cap)) },
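For context, the cache arm shown frees a type-erased allocation by rebuilding a `Vec` whose element type matches the recorded alignment, so the deallocation layout matches the original allocation; since the stored pointer is evidently already `*mut u8`, the `as *mut u8` cast in the alignment-1 arm was redundant. A self-contained sketch of that reclaim-and-drop round trip (illustrative only, not dfdx's cache):

```rust
use std::mem::ManuallyDrop;

fn main() {
    // Pretend this is a cached allocation: leak a Vec into raw parts.
    let v: Vec<u8> = vec![0u8; 64];
    let mut v = ManuallyDrop::new(v);
    let (ptr, len, cap) = (v.as_mut_ptr(), v.len(), v.capacity());

    // Freeing it later means rebuilding a Vec with the same element type,
    // length, and capacity; `ptr` is already `*mut u8`, so no cast is needed.
    unsafe { drop(Vec::from_raw_parts(ptr, len, cap)) };
}
```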
6 changes: 1 addition & 5 deletions src/tensor/cpu/index.rs
@@ -11,11 +11,7 @@ pub(crate) fn index_to_i<S: Shape>(shape: &S, strides: &S::Concrete, index: S::C
panic!("Index out of bounds: index={index:?} shape={shape:?}");
}
}
-strides
-.into_iter()
-.zip(index.into_iter())
-.map(|(a, b)| a * b)
-.sum()
+strides.into_iter().zip(index).map(|(a, b)| a * b).sum()
}

impl<S: Shape, E: Unit, T> std::ops::Index<S::Concrete> for Tensor<S, E, Cpu, T> {
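`index_to_i` computes the flat offset as a dot product of strides and the multi-dimensional index; the rewrite only drops a redundant `.into_iter()`, since `zip` accepts any `IntoIterator`. A standalone sketch of the same computation over plain arrays:

```rust
fn index_to_flat(strides: [usize; 3], index: [usize; 3]) -> usize {
    strides.into_iter().zip(index).map(|(a, b)| a * b).sum()
}

fn main() {
    // Row-major strides for a 2 x 3 x 4 tensor are [12, 4, 1],
    // so element [1, 2, 3] lives at offset 12 + 8 + 3 = 23.
    assert_eq!(index_to_flat([12, 4, 1], [1, 2, 3]), 23);
}
```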
2 changes: 1 addition & 1 deletion src/tensor/ghost.rs
@@ -17,7 +17,7 @@ pub struct GhostTensor<S: Shape, E, D: Storage<E>> {
impl<S: Shape, E, D: Storage<E>, T> Tensor<S, E, D, T> {
/// Creates a ghost tensor that doesn't hold a reference
/// to the tensor's data.
-pub(crate) fn ghost(&self) -> GhostTensor<S, E, D> {
+pub fn ghost(&self) -> GhostTensor<S, E, D> {
GhostTensor {
id: self.id,
len: self.device.len(&self.data),
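With `ghost()` now public, code outside the crate can obtain a `GhostTensor`, which records a tensor's id and length (plus shape and device through its type parameters) without keeping the data alive. A minimal sketch of the idea using plain stand-in types rather than dfdx's actual structs:

```rust
use std::sync::Arc;

struct Tensor {
    id: u64,
    data: Arc<Vec<f32>>,
}

// Metadata-only view: no `Arc` to the data, so it does not extend its lifetime.
struct GhostTensor {
    id: u64,
    len: usize,
}

impl Tensor {
    fn ghost(&self) -> GhostTensor {
        GhostTensor { id: self.id, len: self.data.len() }
    }
}

fn main() {
    let t = Tensor { id: 7, data: Arc::new(vec![0.0; 16]) };
    let g = t.ghost();
    drop(t); // the ghost outlives the tensor and its buffer
    assert_eq!((g.id, g.len), (7, 16));
}
```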
26 changes: 11 additions & 15 deletions src/tensor/gradients.rs
@@ -44,20 +44,16 @@ impl<E, D: Storage<E>> Gradients<E, D> {

impl<E, D: Storage<E>> Gradients<E, D> {
/// Retrieves mutable gradient for `t`, allocating one if it isn't present.
-pub(crate) fn get_or_alloc_mut<S: Shape>(
+pub fn get_or_alloc_mut<S: Shape>(
&mut self,
-t: &Tensor<S, E, D>,
+t: &impl Tensorlike<S, E, D>,
) -> Result<&mut D::Vec, D::Err> {
-let ghost = t.ghost();
-self.try_alloc_for(&ghost)?;
-Ok(self.get_mut(&ghost))
+self.try_alloc_for(t)?;
+Ok(self.get_mut(t))
}

/// Inserts a gradient for `t`
-pub(crate) fn try_alloc_for<S: Shape>(
-&mut self,
-t: &impl Tensorlike<S, E, D>,
-) -> Result<(), D::Err> {
+pub fn try_alloc_for<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> Result<(), D::Err> {
if let std::collections::btree_map::Entry::Vacant(e) = self.gradient_by_id.entry(t.id()) {
e.insert(t.try_alloc_grad()?);
}
Expand Down Expand Up @@ -92,7 +88,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
self.gradient_by_id.get_mut(&t.id()).unwrap()
}

-/// Returns a mutable reference to the data associated with `t`.
+/// Returns an immutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn get_ref<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &D::Vec {
@@ -104,14 +100,14 @@ impl<E, D: Storage<E>> Gradients<E, D> {
/// # Panics
/// If no data is associated with `t` yet, this will panic due to an unwrap()
/// on a .get() to the underlying hashmap.
-pub fn get<S: Shape, T>(&self, t: &Tensor<S, E, D, T>) -> Tensor<S, E, D> {
-let buf = self.gradient_by_id.get(&t.id).unwrap().clone();
+pub fn get<S: Shape>(&self, t: &impl Tensorlike<S, E, D>) -> Tensor<S, E, D> {
+let buf = self.gradient_by_id.get(&t.id()).unwrap().clone();
Tensor {
id: unique_id(),
data: std::sync::Arc::new(buf),
-shape: t.shape,
-strides: t.strides,
-device: t.device.clone(),
+shape: *t.shape(),
+strides: t.strides(),
+device: t.dev().clone(),
tape: Default::default(),
}
}
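These changes make `get_or_alloc_mut` and `try_alloc_for` public and let them accept any `Tensorlike`, with the allocation done only when no entry exists for the tensor's id. A standalone sketch of that allocate-if-vacant pattern; the `u64` key and `Vec<f32>` buffer are stand-ins for dfdx's tensor id and `D::Vec`:

```rust
use std::collections::{btree_map::Entry, BTreeMap};

fn get_or_alloc_mut(grads: &mut BTreeMap<u64, Vec<f32>>, id: u64, len: usize) -> &mut Vec<f32> {
    // Allocate a zeroed gradient only the first time this id is seen.
    if let Entry::Vacant(e) = grads.entry(id) {
        e.insert(vec![0.0; len]);
    }
    grads.get_mut(&id).unwrap()
}

fn main() {
    let mut grads = BTreeMap::new();
    get_or_alloc_mut(&mut grads, 42, 4)[0] += 1.0;
    get_or_alloc_mut(&mut grads, 42, 4)[0] += 1.0; // reuses the existing buffer
    assert_eq!(grads[&42][0], 2.0);
}
```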
6 changes: 3 additions & 3 deletions src/tensor_ops/matmul/cpu_kernel.rs
@@ -71,14 +71,14 @@ impl MatMulImpl<half::f16> for Cpu {
m.size(),
n.size(),
k.size(),
-cp as *mut gemm::f16,
+cp,
cstr[1] as isize,
cstr[0] as isize,
accum,
-ap as *const gemm::f16,
+ap,
astr[1] as isize,
astr[0] as isize,
-bp as *const gemm::f16,
+bp,
Review comment (Owner):
We still need these - internally gemm does a TypeId check against the type. So if gemm is using a different version of the half crate than dfdx, this will fail.

bstr[1] as isize,
bstr[0] as isize,
if accum {
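The owner's comment is about type identity across crate versions: if gemm is compiled against a different release of `half` than dfdx, then dfdx's `half::f16` and `gemm::f16` are distinct Rust types with identical layout, so gemm's internal `TypeId` check will not recognize the value, which is why the comment says the casts are still needed. A small illustration of that phenomenon with stand-in types (not the gemm API):

```rust
use std::any::TypeId;

// Stand-ins for `half::f16` coming from two different crate versions:
// identical layout, but distinct types as far as `TypeId` is concerned.
#[derive(Clone, Copy)]
#[repr(transparent)]
struct F16FromCrateV1(u16);

#[derive(Clone, Copy)]
#[repr(transparent)]
struct F16FromCrateV2(u16);

fn main() {
    assert_ne!(TypeId::of::<F16FromCrateV1>(), TypeId::of::<F16FromCrateV2>());

    // A TypeId-based dispatch would not recognize the "foreign" f16, which is
    // why a pointer cast to the callee's own f16 type does the bridging:
    let xs = [F16FromCrateV1(0x3c00); 4]; // 0x3C00 is 1.0 in IEEE half precision
    let p = xs.as_ptr() as *const F16FromCrateV2; // same layout, different TypeId
    let _ = p;
}
```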