Skip to content

Commit

Permalink
Implement for argmax too
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Apr 3, 2024
1 parent 07e172e commit 5e86b81
Showing 1 changed file with 20 additions and 34 deletions.
54 changes: 20 additions & 34 deletions src/logits_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,70 +97,56 @@ impl LogitsProcessor {
}

fn sample_argmax(&mut self, logits: Tensor) -> Result<Logprobs> {
let mut logits_v: Vec<f32> = logits.to_vec1()?;
let mut probs: Vec<f32> = logits.to_vec1()?;
let argsort_indices = (0..probs.len()).collect::<Vec<_>>();

self.apply_logit_bias(&mut logits_v)?;
self.apply_logit_bias(&mut probs)?;

let next_token = logits_v
let next_token = probs
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i)
.unwrap();
let logprob = logits_v[next_token].log(10.0);
let tok = logits_v[next_token];

let mut sorted = logits_v.clone();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
// Where the next token is in the sorted
let next_token_index = sorted
.binary_search_by(|w| {
if *w <= tok {
Ordering::Less
} else {
Ordering::Greater
}
})
.unwrap_err();
let logprob = probs[next_token].log(10.0);

let mut argsort_indices_sorted = argsort_indices.clone();
// Sort by descending prob
argsort_indices_sorted.sort_by(|a, b| probs[*b].partial_cmp(&probs[*a]).unwrap());
// These are where the top n are
let top_n_toks_range =
next_token_index.saturating_sub(self.top_n_logprobs)..next_token_index;
0..self.top_n_logprobs;
dbg!(&top_n_toks_range);
// The top n's values
let top_n_logprobs = sorted[top_n_toks_range]
let top_n_logprobs = argsort_indices_sorted[top_n_toks_range.clone()]
.iter()
.map(|x| x.log(10.0))
.map(|x| probs[*x].log(10.0))
.collect::<Vec<_>>();
// Find where they actually are in the logits
let mut top_n_toks = Vec::new();
for val in top_n_logprobs.iter() {
let idx = logits_v
.binary_search_by(|w| {
if *w <= *val {
Ordering::Less
} else {
Ordering::Greater
}
})
.unwrap_err();
top_n_toks.push(idx as u32);
for val in top_n_toks_range {
top_n_toks.push(argsort_indices[val]);
}

let mut bytes = Vec::new();
dbg!(&top_n_toks);
for tok in &top_n_toks {
bytes.push(
self.tokenizer
.decode(&[*tok], true)
.decode(&[*tok as u32], true)
.map_err(|x| Error::Msg(x.to_string()))?,
);
}
let top_logprobs = zip(bytes, zip(top_n_toks, top_n_logprobs))
.map(|(bytes, (token, logprob))| TopLogprob {
token,
token: token as u32,
logprob,
bytes,
})
.collect::<Vec<_>>();

dbg!(next_token);

Ok(Logprobs {
token: next_token as u32,
logprob,
Expand Down

0 comments on commit 5e86b81

Please sign in to comment.