Qwen2VL - Implement get_rope_index
xenova committed Nov 24, 2024
1 parent 95688fc commit 2e945a5
Showing 1 changed file with 217 additions and 51 deletions.
268 changes: 217 additions & 51 deletions src/models.js
@@ -98,19 +98,20 @@ import {

import {
cat,
mean,
zeros,
zeros_like,
ones,
ones_like,
full,
full_like,
stack,
std_mean,
Tensor,
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';

import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { LogitsSampler } from './generation/logits_sampler.js';
import { apis } from './env.js';
@@ -628,38 +629,53 @@ async function imageTextToTextForward(self, {
}

/**
* Helper function to perform the following:
* ```python
* x = attention_mask.long().cumsum(-1) - 1
* x.masked_fill_(attention_mask == 0, 1)
* ```
* @param {Tensor} attention_mask
* @returns {{data: BigInt64Array, dims: number[]}}
*/
function cumsum_masked_fill(attention_mask) {
const [bz, seq_len] = attention_mask.dims;
const attn_mask_data = attention_mask.data;

const data = new BigInt64Array(attn_mask_data.length);
for (let i = 0; i < bz; ++i) {
const start = i * seq_len;
let sum = BigInt(0);
for (let j = 0; j < seq_len; ++j) {
const index = start + j;
if (attn_mask_data[index] === 0n) {
data[index] = BigInt(1);
} else { // === 1n
data[index] = sum;
sum += attn_mask_data[index];
}
}
}
return { data, dims: attention_mask.dims };

}
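
// Illustrative usage (not part of the original file): for a left-padded batch
// with attention_mask of dims [2, 4] and data [0n, 0n, 1n, 1n, 1n, 1n, 1n, 1n],
// cumsum_masked_fill returns data [1n, 1n, 0n, 1n, 0n, 1n, 2n, 3n] with dims
// [2, 4]: padding slots are filled with 1 and real tokens count up from 0.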

/**
* If the model supports providing position_ids, we create position_ids on the fly for batch generation,
* by computing the cumulative sum of the attention mask along the sequence length dimension.
*
* Equivalent to:
* ```python
* position_ids = attention_mask.long().cumsum(-1) - 1
* position_ids.masked_fill_(attention_mask == 0, 1)
* if past_key_values:
* position_ids = position_ids[:, -input_ids.shape[1] :]
* ```
*/
function createPositionIds(model_inputs, past_key_values = null) {
const { input_ids, inputs_embeds, attention_mask } = model_inputs;

const { data, dims } = cumsum_masked_fill(attention_mask);
let position_ids = new Tensor('int64', data, dims);
if (past_key_values) {
const offset = -(input_ids ?? inputs_embeds).dims.at(1);
position_ids = position_ids.slice(null, [offset, null]);
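// Illustrative example (not in the original file): with attention_mask
// [[1, 1, 1, 1]] and a cache of 3 past tokens, the cumsum above yields
// position_ids [[0, 1, 2, 3]] and the slice keeps only the last column,
// [[3]], the position of the single new input token.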
@@ -4060,65 +4076,169 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
* - mrope_position_deltas: Tensor of shape `(batch_size)`.
*/
get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) {
// @ts-ignore
const { vision_config, image_token_id, video_token_id, vision_start_token_id } = this.config;
const spatial_merge_size = vision_config.spatial_merge_size ?? 2;

const mrope_position_deltas = [];
if (image_grid_thw || video_grid_thw) {
let total_input_ids = input_ids.tolist();
if (!attention_mask) {
attention_mask = ones_like(input_ids);
}

const attention_mask_list = attention_mask.tolist();
const position_ids_list = Array.from({ length: 3 }, _ => Array.from({ length: input_ids.dims[0] }, _ => Array.from({ length: input_ids.dims[1] }, _ => 1)));

const image_grid_thw_list = image_grid_thw ? image_grid_thw.tolist() : [];
const video_grid_thw_list = video_grid_thw ? video_grid_thw.tolist() : [];

let image_index = 0;
let video_index = 0;
for (let i = 0; i < total_input_ids.length; ++i) {
const ids = total_input_ids[i].filter((_, j) => attention_mask_list[i][j] == 1);

const vision_start_indices = ids.reduce((acc, x, idx) => {
if (x == vision_start_token_id) acc.push(idx);
return acc;
}, []);

const vision_tokens = vision_start_indices.map(x => ids[x + 1]);

const image_nums = vision_tokens.filter(x => x == image_token_id).length;
const video_nums = vision_tokens.filter(x => x == video_token_id).length;

let llm_pos_ids_list = [];
let st = 0;
let remain_images = image_nums;
let remain_videos = video_nums;
for (let j = 0; j < vision_tokens.length; ++j) {
const next_image_token = ids.findIndex((x, i) => i > st && x == image_token_id);
const next_video_token = ids.findIndex((x, i) => i > st && x == video_token_id);

const ed_image = (remain_images > 0 && next_image_token !== -1)
? next_image_token
: ids.length + 1;

const ed_video = (remain_videos > 0 && next_video_token !== -1)
? next_video_token
: ids.length + 1;

let ed;
let t, h, w;
if (ed_image < ed_video) {
([t, h, w] = image_grid_thw_list[image_index]);
++image_index;
--remain_images;
ed = ed_image;
} else {
([t, h, w] = video_grid_thw_list[video_index]);
++video_index;
--remain_videos;
ed = ed_video;
}

const [llm_grid_t, llm_grid_h, llm_grid_w] = [
Number(t),
Math.floor(Number(h) / spatial_merge_size),
Math.floor(Number(w) / spatial_merge_size)
]
const text_len = ed - st;
const st_idx = llm_pos_ids_list.length > 0
? max(llm_pos_ids_list.at(-1))[0] + 1
: 0;

llm_pos_ids_list.push(
Array.from({ length: 3 * text_len }, (_, i) => st_idx + (i % text_len))
)

const offset = text_len + st_idx;
const grid_size = llm_grid_t * llm_grid_h * llm_grid_w;
const t_index = Array.from({ length: grid_size }, (_, i) => offset + Math.floor(i / (llm_grid_h * llm_grid_w)))
const h_index = Array.from({ length: grid_size }, (_, i) => offset + Math.floor(i / llm_grid_w) % llm_grid_h)
const w_index = Array.from({ length: grid_size }, (_, i) => offset + i % llm_grid_w)

llm_pos_ids_list.push([t_index, h_index, w_index].flat())

st = ed + grid_size;
}

if (st < ids.length) {
const st_idx = llm_pos_ids_list.length > 0
? max(llm_pos_ids_list.at(-1))[0] + 1
: 0;
const text_len = ids.length - st;

llm_pos_ids_list.push(
Array.from({ length: 3 * text_len }, (_, i) => (st_idx + (i % text_len)))
)
}

// NOTE: Each item in llm_pos_ids_list is an array of shape (3, text_len),
// meaning to perform concatenation along dim=1, we can do the following:
const num_items = llm_pos_ids_list.reduce((acc, x) => acc + x.length, 0);
const llm_positions = new Array(num_items);
let index = 0;
for (let x = 0; x < 3; ++x) {
for (let y = 0; y < llm_pos_ids_list.length; ++y) {
const val = llm_pos_ids_list[y];
const text_len = val.length / 3;
for (let z = x * text_len; z < (x + 1) * text_len; ++z) {
llm_positions[index++] = val[z];
}
}
}
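
// Layout sketch (illustrative): if llm_pos_ids_list holds two flattened
// (3, text_len) chunks [t0 t1 | h0 h1 | w0 w1] and [t2 | h2 | w2], the loops
// above interleave them per dimension into
// [t0 t1 t2 | h0 h1 h2 | w0 w1 w2], the concatenation along dim=1.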

let count = 0;
const attn_mask = attention_mask_list[i];
for (let y = 0; y < attn_mask.length; ++y) {
if (attn_mask[y] == 1) {
for (let x = 0; x < 3; ++x) {
position_ids_list[x][i][y] = llm_positions[x * num_items / 3 + count];
}
++count;
}
}

const max_llm_positions = max(llm_positions)[0];
mrope_position_deltas.push(max_llm_positions + 1 - total_input_ids[i].length);
}

return [
new Tensor('int64', position_ids_list.flat(Infinity), [3, input_ids.dims[0], input_ids.dims[1]]),
new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
];

} else { // Text-only
if (attention_mask) {
const { data, dims } = cumsum_masked_fill(attention_mask);

const position_ids = BigInt64Array.from(
{ length: 3 * data.length },
(_, i) => data[i % data.length]
);
const mrope_position_deltas = Array.from(
{ length: dims[0] },
(_, i) => max(data.subarray(dims[1] * i, dims[1] * (i + 1)))[0] + 1 - dims[1]
);

return [
new Tensor('int64', position_ids, [3, ...dims]),
new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
]
} else {
const [batch_size, seq_length] = input_ids.dims;
const position_ids = BigInt64Array.from(
{ length: 3 * batch_size * seq_length },
(_, i) => BigInt(i % seq_length),
);

return [
new Tensor('int64', position_ids, [3, ...input_ids.dims]),
zeros([batch_size, 1]),
]
}
}
}
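
// Worked example (illustrative, hypothetical numbers): for a merged vision
// grid (llm_grid_t, llm_grid_h, llm_grid_w) = (1, 2, 2) starting at offset 5,
// the index arithmetic above gives:
//   t_index = [5, 5, 5, 5]
//   h_index = [5, 5, 6, 6]
//   w_index = [5, 6, 5, 6]
// so each vision token gets a distinct (t, h, w) triple, while text tokens
// keep identical positions in all three rows.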

async encode_image({ pixel_values, image_grid_thw }) {
const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, grid_thw: image_grid_thw })).image_features;
@@ -4131,10 +4251,56 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
input_ids,
attention_mask,
}) {
// @ts-ignore
const { image_token_id } = this.config;
const image_tokens = input_ids.tolist().map(ids =>
ids.reduce((acc, x, idx) => {
if (x == image_token_id) acc.push(idx);
return acc;
}, [])
);
const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
const n_image_features = image_features.dims[0];
if (n_image_tokens !== n_image_features) {
throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
}

// Equivalent to performing a masked_scatter
let img = 0;
for (let i = 0; i < image_tokens.length; ++i) {
const tokens = image_tokens[i];
const embeds = inputs_embeds[i];
for (let j = 0; j < tokens.length; ++j) {
embeds[tokens[j]].data.set(image_features[img++].data)
}
}
return { inputs_embeds, attention_mask }
}
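
// Illustrative example (not in the original file): for input_ids
// [[BOS, IMG, IMG, TXT]], where IMG is image_token_id, and image_features of
// dims [2, hidden_size], the two IMG rows of inputs_embeds are overwritten in
// place with image_features[0] and image_features[1], in order of appearance,
// mirroring PyTorch's masked_scatter.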

prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
// Overwritten -- in specific circumstances we don't want to forward image inputs to the model
if (model_inputs.attention_mask && !model_inputs.position_ids) {
// Calculate position_ids and rope_deltas
if (!model_inputs.past_key_values) {
([model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index(
model_inputs.input_ids,
model_inputs.image_grid_thw,
model_inputs.video_grid_thw,
model_inputs.attention_mask,
));

} else {
model_inputs.pixel_values = null;
// model_inputs.pixel_values_videos = null;

const delta = BigInt(Object.values(model_inputs.past_key_values)[0].dims.at(-2));
const rope_deltas_list = model_inputs.rope_deltas.map(x => delta + x);
model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0)
}
}

return model_inputs;
}
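
// Decode-step sketch (illustrative): once past_key_values holds L tokens, the
// branch above assigns every rotary dimension of the new token position
// L + rope_delta, so position_ids collapses to shape (3, batch_size, 1) and
// get_rope_index does not need to be re-run at each step.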
}


