Skip to content

Commit

Permalink
Add support for Qwen2VLProcessor
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Nov 20, 2024
1 parent 76c132f commit 6146f0b
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/base/processing_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,18 @@ export class Processor extends Callable {
get feature_extractor() {
return this.components.feature_extractor;
}


apply_chat_template(messages, options = {}) {
if (!this.tokenizer) {
throw new Error('Unable to apply chat template without a tokenizer.');
}
return this.tokenizer.apply_chat_template(messages, {
tokenize: false, // default to false
...options,
});
}


/**
* Calls the feature_extractor function with the given input.
* @param {any} input The input to extract features from.
Expand Down Expand Up @@ -110,15 +121,15 @@ export class Processor extends Callable {
const [config, components] = await Promise.all([
// TODO:
this.uses_processor_config
? getModelJSON(pretrained_model_name_or_path, PROCESSOR_NAME, true, options)
: {},
? getModelJSON(pretrained_model_name_or_path, PROCESSOR_NAME, true, options)
: {},
Promise.all(
this.classes
.filter((cls) => cls in this)
.map(async (cls) => {
const component = await this[cls].from_pretrained(pretrained_model_name_or_path, options);
return [cls.replace(/_class$/,''), component];
})
.filter((cls) => cls in this)
.map(async (cls) => {
const component = await this[cls].from_pretrained(pretrained_model_name_or_path, options);
return [cls.replace(/_class$/, ''), component];
})
).then(Object.fromEntries)
]);

Expand Down
1 change: 1 addition & 0 deletions src/models/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export * from './mgp_str/processing_mgp_str.js';
export * from './janus/processing_janus.js';
export * from './owlvit/processing_owlvit.js';
export * from './pyannote/processing_pyannote.js';
export * from './qwen2_vl/processing_qwen2_vl.js';
export * from './sam/processing_sam.js';
export * from './speecht5/processing_speecht5.js';
export * from './wav2vec2/processing_wav2vec2.js';
Expand Down
52 changes: 52 additions & 0 deletions src/models/qwen2_vl/processing_qwen2_vl.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import { Processor } from "../../base/processing_utils.js";
import { AutoImageProcessor } from "../auto/image_processing_auto.js";
import { AutoTokenizer } from "../../tokenizers.js";
import { RawImage } from "../../utils/image.js";

export class Qwen2VLProcessor extends Processor {
    static image_processor_class = AutoImageProcessor
    static tokenizer_class = AutoTokenizer

    /**
     * Prepares model inputs by combining tokenized text with processed image features.
     *
     * Each `<|image_pad|>` marker in the text is expanded to one pad token per visual
     * token the model will receive for the corresponding image: the image grid's
     * t*h*w patch count divided by `merge_size ** 2` (patches merged per token).
     *
     * @param {string|string[]} text - Prompt text(s), possibly containing `<|image_pad|>` markers.
     * @param {RawImage|RawImage[]} [images=null] - Image(s) to preprocess alongside the text.
     * @param {...any} args - Unused; reserved for future inputs (e.g. videos).
     * @returns {Promise<any>} Tokenized text inputs merged with the image processor's outputs.
     */
    async _call(text, images = null, ...args) {

        if (!Array.isArray(text)) {
            text = [text];
        }

        let image_inputs;
        let image_grid_thw;

        if (images) {
            image_inputs = await this.image_processor(images);
            image_grid_thw = image_inputs.image_grid_thw;
        }

        if (image_grid_thw) {
            // Number of spatial patches merged into a single visual token.
            const merge_length = this.image_processor.config.merge_size ** 2;
            const image_grid_thw_list = image_grid_thw.tolist();

            // Expand markers in a single pass. `replaceAll` with a replacer callback
            // visits each occurrence left-to-right exactly once and never rescans the
            // replacement text, so (a) no O(n*m) repeated `includes` scans, and (b) no
            // `<|placeholder|>` round-trip that would clobber a literal
            // "<|placeholder|>" appearing in user-supplied text.
            let index = 0;
            text = text.map((t) =>
                t.replaceAll("<|image_pad|>", () => {
                    // t*h*w patches for the image consumed by this marker.
                    const num_patches = image_grid_thw_list[index++].reduce((a, b) => a * b, 1);
                    return "<|image_pad|>".repeat(Math.floor(num_patches / merge_length));
                }),
            );
        }

        const text_inputs = this.tokenizer(text);

        return {
            ...text_inputs,
            ...image_inputs,
            // TODO: ...videos_inputs,
        }
    }
}
36 changes: 36 additions & 0 deletions tests/processors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -1038,5 +1038,41 @@ describe("Processors", () => {
},
MAX_TEST_EXECUTION_TIME,
);

describe(
    "Qwen2VLProcessor",
    () => {
        /** @type {import('../src/transformers.js').Qwen2VLProcessor} */
        let processor;
        let images = {};

        // Load the processor and test fixture once for the whole suite.
        beforeAll(async () => {
            processor = await AutoProcessor.from_pretrained(MODELS.qwen2_vl);
            images = { white_image: await load_image(TEST_IMAGES.white_image) };
        });

        it("Image and text", async () => {
            const messages = [
                {
                    role: "user",
                    content: [{ type: "image" }, { type: "text", text: "Describe this image." }],
                },
            ];

            // Render the chat template, then run the full processor pipeline.
            const prompt = processor.apply_chat_template(messages, { add_generation_prompt: true });
            const outputs = await processor(prompt, images.white_image);

            compare(outputs.input_ids.dims, [1, 89]);
            compare(outputs.attention_mask.dims, [1, 89]);
            compare(outputs.pixel_values.dims, [256, 1176]);
            compare(outputs.image_grid_thw.dims, [1, 3]);
        });
    },
    MAX_TEST_EXECUTION_TIME,
);
});
});

0 comments on commit 6146f0b

Please sign in to comment.