Parallel loading of the model tensors #79
Sort of related to speeding up loading: I've been messing around with rewriting it to use a … This is what just loading the header and vocabulary looks like:

```rust
pub mod mmap_loader {
    use mmap_rs::{MmapFlags, MmapOptions};
    #[allow(unused_imports)]
    use nom::{
        branch::alt,
        bytes::complete as nby,
        combinator as ncom,
        error::ParseError,
        multi as nm,
        number::complete::{self as nnum, le_f32, le_i32, le_u32},
        sequence as nseq, IResult, Parser, Slice,
    };
    use std::fs::File;

    // Hyperparameters, Vocabulary, TokenId, LoadError, ggml, HashMap and Path
    // all come from the parent module.
    use super::*;

    pub struct Flib;

    #[derive(Debug)]
    struct Header {
        legacy: bool,
        hyper: Hyperparameters,
    }

    impl Flib {
        /// Magic number (versioned vs. legacy/unversioned file), then the hyperparameters.
        fn parse_header(i: &[u8]) -> IResult<&[u8], Header> {
            let (i, magic) = le_i32(i)?;
            let legacy = match magic {
                ggml::FILE_MAGIC => false,
                ggml::FILE_MAGIC_UNVERSIONED => true,
                _ => return nom::error::context("unrecognized file magic", ncom::fail)(i),
            };
            ncom::map(Flib::parse_hyperparameters, move |hyper| Header {
                legacy,
                hyper,
            })(i)
        }

        /// Seven consecutive little-endian i32 fields. n_ctx isn't stored in the
        /// file, so it is left at 0 here.
        fn parse_hyperparameters(i: &[u8]) -> IResult<&[u8], Hyperparameters> {
            ncom::map(
                nseq::tuple((le_i32, le_i32, le_i32, le_i32, le_i32, le_i32, le_i32)),
                |(n_vocab, n_embd, n_mult, n_head, n_layer, n_rot, f16_)| Hyperparameters {
                    n_vocab,
                    n_ctx: 0,
                    n_embd,
                    n_mult,
                    n_head,
                    n_layer,
                    n_rot,
                    f16_,
                },
            )(i)
        }

        /// Exactly n_vocab entries: a u32 length, that many token bytes, then an f32
        /// score (absent in legacy files). Tokens that aren't valid UTF-8 get a
        /// placeholder in id_to_token and are left out of token_to_id.
        fn parse_vocabulary<'a>(i: &'a [u8], hdr: &Header) -> IResult<&'a [u8], Vocabulary> {
            const TOKEN_PLACEHOLDER: &str = "�";
            let n_vocab = hdr.hyper.n_vocab as usize;
            let legacy = hdr.legacy;
            let mut id_to_token = Vec::with_capacity(n_vocab);
            let mut id_to_token_score = Vec::with_capacity(n_vocab);
            let mut token_to_id = HashMap::with_capacity(n_vocab);

            // One (token bytes, score) pair; legacy files default the score to 0.0.
            let vocabitem_parser = |i| {
                nseq::tuple((nm::length_data(le_u32), ncom::cond(!legacy, le_f32)))(i)
                    .map(|(i, (sbytes, score))| (i, (sbytes, score.unwrap_or_default())))
            };

            // Fold step: stores the token and its score, returns the running max token length.
            let folf = |mut mtl: usize, (sbytes, score)| {
                let tid = id_to_token.len();
                let (ok, token) = std::str::from_utf8(sbytes).map_or_else(
                    |_| (false, TOKEN_PLACEHOLDER.to_string()),
                    |s| (true, s.to_string()),
                );
                if ok {
                    mtl = mtl.max(token.len());
                    token_to_id.insert(token.clone(), tid as TokenId);
                }
                id_to_token.push(token);
                id_to_token_score.push(score);
                mtl
            };

            let (i, max_token_length) =
                nm::fold_many_m_n(n_vocab, n_vocab, vocabitem_parser, || 0, folf)(i)?;
            Ok((
                i,
                Vocabulary {
                    id_to_token,
                    id_to_token_score,
                    token_to_id,
                    max_token_length,
                },
            ))
        }

        pub fn load(path: impl AsRef<Path>) -> Result<(), LoadError> {
            let path = path.as_ref();
            let fp = File::open(path).map_err(|e| LoadError::OpenFileFailed {
                source: e,
                path: path.to_owned(),
            })?;
            let flen = fp.metadata()?.len();
            // Map the whole file (excluded from core dumps); parsing then runs on a plain &[u8].
            let m = unsafe {
                MmapOptions::new(flen as usize).and_then(|mo| {
                    mo.with_file(fp, 0)
                        .with_flags(MmapFlags::NO_CORE_DUMP)
                        .map()
                })
            }
            .map_err(|e| LoadError::MmapFailed { source: e })?;
            let mb = m.as_slice();

            // Experimental: parse failures panic via unwrap() instead of becoming a LoadError.
            let (i, hdr) = Self::parse_header(mb).unwrap();
            println!("Got: {hdr:?}");
            // The remaining input (_rest) starts at the tensor data, which isn't parsed here.
            let (_rest, vocab) = Self::parse_vocabulary(i, &hdr).unwrap();
            println!(
                "Got: {} - {} - {}",
                vocab.max_token_length,
                vocab.id_to_token.len(),
                vocab.token_to_id.len()
            );
            Ok(())
        }
    }
}
```

I honestly don't love writing parsers in Rust; it's so much nicer in Haskell, but I guess this is more readable than the current code. A long time ago, I experimented with combining nom and monadic do-style notation, but it wasn't really practical: https://github.com/KerfuffleV2/mdoexperiments
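For comparison, the same header and vocabulary layout can be walked with only the standard library. This is a minimal sketch, not the loader's actual code: it assumes the input slice already starts just after the magic number (magic and any version handling are left out), and the names `take`, `array4`, `Hyper`, `parse_hyper` and `parse_vocab` are hypothetical. It mirrors the nom version's field order, the score that legacy files omit, and the `�` (U+FFFD) placeholder for tokens that aren't valid UTF-8, returning `None` on truncated input.

```rust
/// Splits `n` bytes off the front of `input` and advances it; None if too short.
fn take<'a>(input: &mut &'a [u8], n: usize) -> Option<&'a [u8]> {
    let whole: &'a [u8] = *input; // copy the slice out so `head` keeps lifetime 'a
    if whole.len() < n {
        return None;
    }
    let (head, rest) = whole.split_at(n);
    *input = rest;
    Some(head)
}

/// Reads the next 4 bytes as an array, ready for the `from_le_bytes` constructors.
fn array4(input: &mut &[u8]) -> Option<[u8; 4]> {
    let bytes = take(input, 4)?;
    let mut out = [0u8; 4];
    out.copy_from_slice(bytes);
    Some(out)
}

fn i32_le(input: &mut &[u8]) -> Option<i32> {
    array4(input).map(i32::from_le_bytes)
}
fn u32_le(input: &mut &[u8]) -> Option<u32> {
    array4(input).map(u32::from_le_bytes)
}
fn f32_le(input: &mut &[u8]) -> Option<f32> {
    array4(input).map(f32::from_le_bytes)
}

/// The seven header fields, in the order the nom version reads them.
#[derive(Debug, PartialEq)]
pub struct Hyper {
    pub n_vocab: i32,
    pub n_embd: i32,
    pub n_mult: i32,
    pub n_head: i32,
    pub n_layer: i32,
    pub n_rot: i32,
    pub f16_: i32,
}

pub fn parse_hyper(input: &mut &[u8]) -> Option<Hyper> {
    // Struct fields are evaluated in source order, so this reads them sequentially.
    Some(Hyper {
        n_vocab: i32_le(input)?,
        n_embd: i32_le(input)?,
        n_mult: i32_le(input)?,
        n_head: i32_le(input)?,
        n_layer: i32_le(input)?,
        n_rot: i32_le(input)?,
        f16_: i32_le(input)?,
    })
}

/// Returns (tokens, scores, max valid token length) for exactly `n_vocab` entries.
pub fn parse_vocab(input: &mut &[u8], n_vocab: usize, legacy: bool) -> Option<(Vec<String>, Vec<f32>, usize)> {
    let mut tokens = Vec::with_capacity(n_vocab);
    let mut scores = Vec::with_capacity(n_vocab);
    let mut max_len: usize = 0;
    for _ in 0..n_vocab {
        let len = u32_le(input)? as usize;
        let bytes = take(input, len)?;
        // Legacy (unversioned) files carry no score; default it to 0.0.
        let score = if legacy { 0.0 } else { f32_le(input)? };
        match std::str::from_utf8(bytes) {
            Ok(s) => {
                max_len = max_len.max(s.len());
                tokens.push(s.to_string());
            }
            Err(_) => tokens.push("\u{FFFD}".to_string()),
        }
        scores.push(score);
    }
    Some((tokens, scores, max_len))
}
```

Each function advances the `&mut &[u8]` cursor in place, so calling `parse_hyper` and then `parse_vocab` on the same slice leaves it just past the vocabulary, much like the remaining input `i` in the nom version.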
Along the lines of programmatic parsing, it might also be interesting to explore the use of https://github.com/jam1garner/binrw. Not sure how that would impact parallel loading or #93, though.
Interesting. Weirdly enough, that actually only has limited support for non-streams (i.e. …)
Don't really need mmap; smol + nuclei + two file descriptors should be enough.
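smol and nuclei are async runtime/I/O crates, and the sketch below does not use them. As a dependency-free stand-in for the "two file descriptors" idea, it opens the file twice, one handle per scoped thread, and each thread seeks to and reads its own byte ranges. The `(offset, len)` list is an assumption (in a real loader it would come from the tensor headers), and `read_ranges_two_fds` is a hypothetical name.

```rust
use std::fs::File;
use std::io::{self, Read, Seek, SeekFrom};
use std::path::Path;
use std::thread;

/// Reads each (offset, len) range of the file at `path` into its own buffer.
/// The ranges are split in half; each half is read on its own thread through
/// its own File handle (i.e. its own file descriptor). Output keeps input order.
pub fn read_ranges_two_fds(path: &Path, ranges: &[(u64, usize)]) -> io::Result<Vec<Vec<u8>>> {
    let (first, second) = ranges.split_at(ranges.len() / 2);

    let read_half = |part: &[(u64, usize)]| -> io::Result<Vec<Vec<u8>>> {
        let mut f = File::open(path)?; // separate descriptor per thread
        let mut out = Vec::with_capacity(part.len());
        for &(offset, len) in part {
            f.seek(SeekFrom::Start(offset))?;
            let mut buf = vec![0u8; len];
            f.read_exact(&mut buf)?; // fails with UnexpectedEof past the end
            out.push(buf);
        }
        Ok(out)
    };

    let (a, b) = thread::scope(|s| {
        let h1 = s.spawn(|| read_half(first));
        let h2 = s.spawn(|| read_half(second));
        (
            h1.join().expect("reader thread panicked"),
            h2.join().expect("reader thread panicked"),
        )
    });

    let mut all = a?;
    all.extend(b?);
    Ok(all)
}
```

Because each thread owns its own `File`, their seek positions never interfere, unlike two threads taking turns seeking on one shared handle.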
With mmap support, I'm not sure how relevant this is now. The loader doesn't do much actual work when setting up the tensors.
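To illustrate why tensor setup is cheap once the file is mapped: each tensor just becomes a slice into the mapped bytes, so setup is bounds checks and bookkeeping rather than copying. This is a minimal sketch, assuming the tensor offsets and lengths are already known from the file's tensor headers; a plain `&[u8]` stands in for the mapping, and `TensorInfo` and `tensor_views` are hypothetical names.

```rust
use std::collections::HashMap;

/// Hypothetical per-tensor metadata: where the tensor's bytes live in the file.
pub struct TensorInfo {
    pub name: String,
    pub offset: usize,
    pub len: usize,
}

/// Builds name -> byte-slice views into `mapped` (a stand-in for the mmap'd file).
/// No tensor data is copied; each range is only bounds-checked.
pub fn tensor_views<'a>(
    mapped: &'a [u8],
    infos: &[TensorInfo],
) -> Result<HashMap<String, &'a [u8]>, String> {
    let mut views = HashMap::with_capacity(infos.len());
    for t in infos {
        let end = t
            .offset
            .checked_add(t.len)
            .ok_or_else(|| format!("{}: offset + len overflows", t.name))?;
        let data = mapped
            .get(t.offset..end)
            .ok_or_else(|| format!("{}: range {}..{} is past the end of the file", t.name, t.offset, end))?;
        views.insert(t.name.clone(), data);
    }
    Ok(views)
}
```

The returned slices borrow from the mapping, so they cannot outlive it; that is what the `'a` lifetime expresses.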
People have reported faster model loading upstream when the tensors are loaded in parallel: ggerganov/llama.cpp#85

This should be pretty easy to do in Rust if we convert loading to an `iter` and then use `par_iter` instead. It seems like this should be I/O bound, but perhaps the actual loading process has computational overhead?
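The "turn loading into an iterator, then parallelize it" idea can be sketched without extra dependencies. The suggestion names rayon's `par_iter`; to keep this example self-contained it swaps in `std::thread::scope`, splitting the tensor list into one chunk per thread and keeping results in the original order. `load_all_parallel` and its `load_one` callback are hypothetical; `load_one` stands for whatever per-tensor work (read, convert, validate) the loader does.

```rust
use std::thread;

/// Applies `load_one` to every item across up to `n_threads` scoped threads.
/// Results come back in the original item order. On failure, returns the first
/// error in item order; other chunks still run to completion (no cancellation).
pub fn load_all_parallel<T, R, E, F>(items: &[T], n_threads: usize, load_one: F) -> Result<Vec<R>, E>
where
    T: Sync,
    R: Send,
    E: Send,
    F: Fn(&T) -> Result<R, E> + Sync,
{
    let n_threads = n_threads.max(1);
    // Ceiling division so every item lands in some chunk; never a zero chunk size.
    let chunk = ((items.len() + n_threads - 1) / n_threads).max(1);
    let load_one = &load_one;

    let per_chunk: Vec<Result<Vec<R>, E>> = thread::scope(|s| {
        // Collect the handles first so all threads start before any join.
        let handles: Vec<_> = items
            .chunks(chunk)
            .map(|part| s.spawn(move || part.iter().map(load_one).collect::<Result<Vec<R>, E>>()))
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("loader thread panicked"))
            .collect()
    });

    let mut out = Vec::with_capacity(items.len());
    for r in per_chunk {
        out.extend(r?);
    }
    Ok(out)
}
```

Whether this actually speeds things up depends on the open question in the issue: if loading is purely I/O bound, extra threads help mainly when the storage can serve several reads at once, while per-tensor CPU work (such as format conversion) parallelizes more directly. Timing it against the sequential loop is the way to tell.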