Skip to content

Commit

Permalink
allow models without scalars to be read without errors (#808)
Browse files Browse the repository at this point in the history
Co-authored-by: Corey Lowman <clowman1993@gmail.com>
  • Loading branch information
nkoppel and coreylowman authored Jul 12, 2023
1 parent b0fbaad commit e45a308
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 25 deletions.
29 changes: 19 additions & 10 deletions src/nn/npz.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
prelude::numpy::NpyError,
shapes::{Dtype, Shape},
tensor::{
numpy::{read_from_npz, write_to_npz, NpzError, NumpyDtype},
Expand Down Expand Up @@ -162,18 +163,26 @@ impl<R: Read + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D>

fn visit_scalar<N: num_traits::NumCast>(
&mut self,
_opts: ScalarOptions<N>,
opts: ScalarOptions<N>,
(n, full_path): (&mut N, String),
) -> Result<Option<N>, Self::Err> {
let buf: Vec<f64> = read_from_npz(self, &[], full_path)?;
*n = N::from(buf[0]).unwrap_or_else(|| {
panic!(
"Failed to convert f64 value {} to {} when reading from npz!",
buf[0],
std::any::type_name::<N>()
)
});
Ok(None)
match read_from_npz::<_, f64>(self, &[], full_path) {
Ok(buf) => {
*n = N::from(buf[0]).unwrap_or_else(|| {
panic!(
"Failed to convert f64 value {} to {} when reading from npz!",
buf[0],
std::any::type_name::<N>()
)
});
Ok(None)
}
Err(NpyError::IoError(e)) if e.kind() == std::io::ErrorKind::NotFound => {
*n = opts.default;
Ok(None)
}
Err(x) => Err(x.into()),
}
}
}

Expand Down
33 changes: 21 additions & 12 deletions src/nn/safetensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,29 @@ impl<'data, E: Dtype + SafeDtype, D: Device<E>> TensorVisitor<E, D> for SafeTens

fn visit_scalar<N: num_traits::NumCast>(
&mut self,
_: ScalarOptions<N>,
opts: ScalarOptions<N>,
(n, full_path): (&mut N, String),
) -> Result<Option<N>, Self::Err> {
let data = self.tensor(&full_path)?.data();
let mut array = [0; 8];
array.copy_from_slice(data);
let val = f64::from_le_bytes(array);
*n = N::from(val).unwrap_or_else(|| {
panic!(
"Failed to convert f64 value {val} at {full_path} to {} when reading from safetensors!",
std::any::type_name::<N>()
)
});
Ok(None)
match self.tensor(&full_path) {
Ok(tensor) => {
let data = tensor.data();
let mut array = [0; 8];
array.copy_from_slice(data);
let val = f64::from_le_bytes(array);
*n = N::from(val).unwrap_or_else(|| {
panic!(
"Failed to convert f64 value {val} at {full_path} to {} when reading from safetensors!",
std::any::type_name::<N>()
)
});
Ok(None)
}
Err(SafeTensorError::TensorNotFound(_)) => {
*n = opts.default;
Ok(None)
}
Err(x) => Err(Error::SafeTensorError(x)),
}
}
}

Expand Down
13 changes: 10 additions & 3 deletions src/tensor/numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,16 @@ pub(crate) fn read_from_npz<R: Read + Seek, E: Dtype + NumpyDtype>(
filename.push_str(".npy");
}

let mut f = r
.by_name(&filename)
.unwrap_or_else(|_| panic!("'{filename}' not found"));
let mut f = match r.by_name(&filename) {
Ok(f) => f,
Err(ZipError::FileNotFound) => {
return Err(NpyError::IoError(io::Error::new(
io::ErrorKind::NotFound,
ZipError::FileNotFound,
)))
}
Err(e) => panic!("Uncaught zip error: {e}"),
};

read_from_npy(&mut f, shape)
}
Expand Down

0 comments on commit e45a308

Please sign in to comment.