[PPMix No.12] Support Janus and JanusFlow inference (PaddlePaddle#839)
Add inference examples for unified multimodal understanding and generation models, including Janus and JanusFlow. --------- Co-authored-by: nemonameless@qq.com@github.com <nemonameless@qq.com>
1 parent e1e3620 · commit 8bd1cba · 18 changed files with 3,850 additions and 1 deletion

**paddlemix/examples/janus/README.md** (new file, 67 additions)

# Janus/JanusFlow

## 1. Model Introduction

[Janus/JanusFlow](https://github.com/deepseek-ai/Janus) decouples visual encoding into separate pathways while still processing everything with a single unified transformer architecture, addressing the limitations of earlier approaches. The decoupling not only eases the conflict between the visual encoder's roles in understanding and generation, but also makes the framework more flexible.

**Model weights supported by this repository:**

| Model |
|--------------------|
| deepseek-ai/Janus-1.3B |
| deepseek-ai/JanusFlow-1.3B |

Note: the names match the Hugging Face weights, but the weights themselves are Paddle-framework tensors. Calling `xxx.from_pretrained("deepseek-ai/Janus-1.3B")` automatically downloads the weight folder to the cache directory.
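
For example, the weights and processors can be loaded directly in Python (a minimal sketch that mirrors the example scripts added in this commit):

```python
from paddlenlp.transformers import LlamaTokenizerFast

from paddlemix.models.janus import JanusMultiModalityCausalLM
from paddlemix.processors import JanusImageProcessor, JanusVLChatProcessor

# Downloads the Paddle weight folder to the local cache on first use.
vl_gpt = JanusMultiModalityCausalLM.from_pretrained("deepseek-ai/Janus-1.3B")
tokenizer = LlamaTokenizerFast.from_pretrained("deepseek-ai/Janus-1.3B")
image_processor = JanusImageProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
processor = JanusVLChatProcessor(image_processor, tokenizer)
```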

## 2. Environment Setup

1) [Install the PaddleMIX dependencies](https://github.com/PaddlePaddle/PaddleMIX/tree/b4f97ff859e1964c839fc5fab94f7ba63b1e5959?tab=readme-ov-file#%E5%AE%89%E8%A3%85)

2) `pip install pillow tqdm paddlenlp==3.0.0b2`

Note: Python 3.10 or later is recommended.

## 3. Quick Start

### Inference
```bash
# Janus understanding
python paddlemix/examples/janus/run_understanding_inference.py \
    --model_path="deepseek-ai/Janus-1.3B" \
    --image_file="paddlemix/demo_images/examples_image1.jpg" \
    --question="What is shown in this image?"

# Janus generation
python paddlemix/examples/janus/run_generation_inference.py \
    --model_path="deepseek-ai/Janus-1.3B" \
    --prompt="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair"

# JanusFlow generation
python paddlemix/examples/janus/run_generation_inference_janusflow.py \
    --model_path="deepseek-ai/JanusFlow-1.3B" \
    --inference_step=30 \
    --prompt="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair"

# Janus interactive chat
python paddlemix/examples/janus/run_interactivechat.py \
    --model_path="deepseek-ai/Janus-1.3B"
```
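
Each generation script saves its outputs as JPEG files in the current working directory: `janus_generated_samples/` for Janus and `janusflow_generated_samples/` for JanusFlow.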

### References
```BibTeX
@article{wu2024janus,
  title={Janus: Decoupling visual encoding for unified multimodal understanding and generation},
  author={Wu, Chengyue and Chen, Xiaokang and Wu, Zhiyu and Ma, Yiyang and Liu, Xingchao and Pan, Zizheng and Liu, Wen and Xie, Zhenda and Yu, Xingkai and Ruan, Chong and others},
  journal={arXiv preprint arXiv:2410.13848},
  year={2024}
}
@misc{ma2024janusflow,
  title={JanusFlow: Harmonizing Autoregression and Rectified Flow for Unified Multimodal Understanding and Generation},
  author={Yiyang Ma and Xingchao Liu and Xiaokang Chen and Wen Liu and Chengyue Wu and Zhiyu Wu and Zizheng Pan and Zhenda Xie and Haowei Zhang and Xingkai Yu and Liang Zhao and Yisong Wang and Jiaying Liu and Chong Ruan},
  journal={arXiv preprint arXiv:2411.07975},
  year={2024}
}
```

**paddlemix/examples/janus/requirements.txt** (new file, 3 additions)

pillow
tqdm
paddlenlp==3.0.0b2

**paddlemix/examples/janus/run_generation_inference.py** (new file, 121 additions)

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import numpy as np
import paddle
import PIL.Image
from paddlenlp.transformers import LlamaTokenizerFast
from tqdm import tqdm

from paddlemix.models.janus import JanusMultiModalityCausalLM
from paddlemix.processors import JanusImageProcessor, JanusVLChatProcessor

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="deepseek-ai/Janus-1.3B")
parser.add_argument(
    "--prompt",
    type=str,
    default="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
)
args = parser.parse_args()

# Load the Paddle weights, tokenizer, and image/chat processors.
vl_gpt = JanusMultiModalityCausalLM.from_pretrained(args.model_path)
tokenizer = LlamaTokenizerFast.from_pretrained(args.model_path)
image_processor = JanusImageProcessor.from_pretrained(args.model_path)
vl_chat_processor: JanusVLChatProcessor = JanusVLChatProcessor(image_processor, tokenizer)

# Wrap the prompt in the SFT chat template and append the image-start tag,
# which switches the model into image-generation mode.
conversation = [
    {
        "role": "User",
        "content": args.prompt,
    },
    {"role": "Assistant", "content": ""},
]
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation, sft_format=vl_chat_processor.sft_format, system_prompt=""
)
prompt = sft_format + vl_chat_processor.image_start_tag


@paddle.no_grad()
def generate(
    mmgpt,
    vl_chat_processor,
    prompt: str,
    temperature: float = 1,
    parallel_size: int = 2,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = paddle.to_tensor(data=input_ids.input_ids, dtype="int64")
    # Interleave conditional (even) and unconditional (odd) rows for
    # classifier-free guidance; unconditional rows keep only their first and
    # last tokens and are padded in between.
    tokens = paddle.zeros(shape=(parallel_size * 2, len(input_ids)), dtype="int32")
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)  # e.g. [4, 50, 2048]
    generated_tokens = paddle.zeros(shape=(parallel_size, image_token_num_per_image), dtype="int32")

    # Autoregressively sample one discrete image token per step, reusing the KV cache.
    for i in tqdm(range(image_token_num_per_image)):
        batch_size, seq_length = inputs_embeds.shape[:2]

        past_key_values_length = outputs.past_key_values[0][0].shape[1] if i != 0 else 0
        position_ids = paddle.arange(past_key_values_length, seq_length + past_key_values_length).expand(
            (batch_size, seq_length)
        )

        outputs = mmgpt.language_model.llama(
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,  # [4, 1, 2048] after the first step
            use_cache=True,
            past_key_values=outputs.past_key_values if i != 0 else None,
            return_dict=True,
        )

        hidden_states = outputs.last_hidden_state
        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]

        # Classifier-free guidance on the logits, then temperature sampling.
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = paddle.nn.functional.softmax(x=logits / temperature, axis=-1)
        next_token = paddle.multinomial(x=probs, num_samples=1)

        generated_tokens[:, i] = next_token.squeeze(axis=-1)
        # Feed the sampled token back to both the conditional and unconditional rows.
        next_token = paddle.concat(x=[next_token.unsqueeze(axis=1), next_token.unsqueeze(axis=1)], axis=1).reshape(
            [-1]
        )
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(axis=1)

    # Decode the image tokens back to pixels and map [-1, 1] to [0, 255].
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype="int32"), shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
    )
    dec = dec.to("float32").cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    os.makedirs("janus_generated_samples", exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join("janus_generated_samples", "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)


generate(vl_gpt, vl_chat_processor, prompt)
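
For reference, the defaults of `generate` above are mutually consistent: with `img_size=384` and `patch_size=16` the vision decoder reconstructs a 24×24 grid of image tokens, so `image_token_num_per_image = (384 / 16)² = 576`.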

**paddlemix/examples/janus/run_generation_inference_janusflow.py** (new file, 128 additions)

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import numpy as np
import paddle
import PIL.Image
from paddlenlp.transformers import LlamaTokenizerFast
from tqdm import tqdm

from paddlemix.models.janus import JanusFlowMultiModalityCausalLM
from paddlemix.processors import JanusImageProcessor, JanusVLChatProcessor
from ppdiffusers.models import AutoencoderKL

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="deepseek-ai/JanusFlow-1.3B")
parser.add_argument(
    "--prompt",
    type=str,
    default="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
)
parser.add_argument("--inference_step", type=int, default=30)

args = parser.parse_args()

# JanusFlow generates in the latent space of the SDXL VAE, which is used at
# the end to decode the final latents into pixels.
vl_gpt = JanusFlowMultiModalityCausalLM.from_pretrained(args.model_path)
tokenizer = LlamaTokenizerFast.from_pretrained(args.model_path)
image_processor = JanusImageProcessor.from_pretrained(args.model_path)
vl_chat_processor: JanusVLChatProcessor = JanusVLChatProcessor(image_processor, tokenizer)
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")

# Wrap the prompt in the SFT chat template and append the image-start tag.
conversation = [
    {
        "role": "User",
        "content": args.prompt,
    },
    {"role": "Assistant", "content": ""},
]
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation, sft_format=vl_chat_processor.sft_format, system_prompt=""
)
prompt = sft_format + vl_chat_processor.image_start_tag


@paddle.no_grad()
def generate(
    vl_gpt,
    vl_chat_processor,
    tokenizer,
    prompt,
    cfg_weight: float = 2.0,
    num_inference_steps: int = 30,
    batch_size: int = 1,
):
    # Duplicate the prompt: the first batch_size rows are conditional, the
    # second batch_size rows are padded (all but the first token) and act as
    # the unconditional branch for classifier-free guidance.
    input_ids = tokenizer(prompt, return_tensors="pd")["input_ids"]
    tokens = paddle.stack(x=[input_ids] * batch_size * 2)[:, 0, :]
    tokens[batch_size:, 1:] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    inputs_embeds = inputs_embeds[:, :-1, :]
    # Start from Gaussian noise in the VAE latent space (4 x 48 x 48) and
    # integrate the flow with a fixed Euler step dt = 1 / num_inference_steps.
    z = paddle.randn(shape=(batch_size, 4, 48, 48), dtype="bfloat16")
    dt = 1.0 / num_inference_steps
    dt = paddle.zeros_like(x=z, dtype="bfloat16") + dt
    # 577 = 576 latent tokens + 1 timestep embedding; unconditional rows mask
    # out the prompt except for the first token.
    attention_mask = paddle.ones(shape=(2 * batch_size, tuple(inputs_embeds.shape)[1] + 577))
    attention_mask[batch_size:, 1 : tuple(inputs_embeds.shape)[1]] = 0
    attention_mask = attention_mask.astype(dtype="int32")

    for step in tqdm(range(num_inference_steps)):
        # Encode the current latent (duplicated for the cond/uncond branches)
        # with the timestep, then align it to the LLM embedding space.
        z_input = paddle.concat(x=[z, z], axis=0)
        t = step / num_inference_steps * 1000.0
        t = paddle.to_tensor(data=[t] * tuple(z_input.shape)[0]).to(dt.place)
        z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
        z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
        z_emb = z_emb.reshape([tuple(z_emb.shape)[0], tuple(z_emb.shape)[1], -1]).transpose(perm=[0, 2, 1])
        z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
        # The prompt embeddings are fed in only at step 0; later steps reuse
        # the prompt's KV cache.
        llm_emb = (
            paddle.concat(x=[inputs_embeds, t_emb.unsqueeze(axis=1), z_emb], axis=1)
            if step == 0
            else paddle.concat(x=[t_emb.unsqueeze(axis=1), z_emb], axis=1)
        )
        bs, seq_len, dim = llm_emb.shape
        past_seen_tokens = inputs_embeds.shape[1] if step != 0 else 0
        position_ids = paddle.arange(past_seen_tokens, past_seen_tokens + seq_len, dtype=paddle.int64).reshape([1, -1])
        outputs = vl_gpt.language_model.llama(
            position_ids=position_ids,
            inputs_embeds=llm_emb,
            use_cache=True,
            attention_mask=attention_mask,
            past_key_values=past_key_values if step != 0 else None,
            return_dict=True,
        )
        if step == 0:
            # Keep only the prompt-length prefix of the KV cache for reuse.
            past_key_values = []
            for kv in outputs.past_key_values:
                # kv tensors are [batch, seq_len, num_heads, head_dim], e.g. [2, 607, 16, 128]
                k, v = kv[0], kv[1]
                past_key_values.append((k[:, : inputs_embeds.shape[1], :, :], v[:, : inputs_embeds.shape[1], :, :]))
            past_key_values = tuple(past_key_values)
        # Decode the last 576 hidden states into a velocity field and apply
        # classifier-free guidance: v = w * v_cond - (w - 1) * v_uncond.
        hidden_states = outputs.last_hidden_state
        hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
        hidden_states = hidden_states.reshape([tuple(z_emb.shape)[0], 24, 24, 768]).transpose(perm=[0, 3, 1, 2])
        v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
        v_cond, v_uncond = paddle.chunk(x=v, chunks=2)
        v = cfg_weight * v_cond - (cfg_weight - 1.0) * v_uncond
        z = z + dt * v  # Euler step of the rectified-flow ODE

    # Decode the final latent with the VAE and save the images.
    decoded_image = vae.decode(z / vae.config.scaling_factor).sample
    images = decoded_image.astype(dtype="float32").clip_(min=-1.0, max=1.0).transpose(perm=[0, 2, 3, 1]).cpu().numpy()
    images = ((images + 1) / 2.0 * 255).astype(np.uint8)

    os.makedirs("janusflow_generated_samples", exist_ok=True)
    for i in range(batch_size):
        save_path = os.path.join("janusflow_generated_samples", "img_{}.jpg".format(i))
        PIL.Image.fromarray(images[i]).save(save_path)


generate(vl_gpt, vl_chat_processor, tokenizer, prompt, num_inference_steps=args.inference_step)
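
The sampling loop above is a plain Euler integration of the rectified-flow ODE: at each of the `num_inference_steps` steps the model predicts a velocity field `v` for the current latent `z`, classifier-free guidance blends the two branches as `v = cfg_weight * v_cond - (cfg_weight - 1) * v_uncond`, and the latent advances by `z = z + v * dt` with `dt = 1 / num_inference_steps`. The final `[batch_size, 4, 48, 48]` latent is decoded by the SDXL VAE (spatial upsampling factor 8) into 384×384 RGB images.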