[PPMix No.12] Support Janus and JanusFlow inference (PaddlePaddle#839)
Add inference examples for unified multimodal understanding and
generation models, including Janus and JanusFlow.

---------

Co-authored-by: nemonameless@qq.com@github.com <nemonameless@qq.com>
cheng221 and nemonameless authored Nov 26, 2024
1 parent e1e3620 commit 8bd1cba
Showing 18 changed files with 3,850 additions and 1 deletion.
2 changes: 2 additions & 0 deletions paddlemix/auto/modeling.py
@@ -59,6 +59,8 @@
"visualglm": "VisualGLMForConditionalGeneration",
"llava_qwen": "LlavaQwenForCausalLM",
"internvl2": "InternVLChatModel",
"janus":"JanusMultiModalityCausalLM",
"janus_flow":"JanusFlowMultiModality",
}


67 changes: 67 additions & 0 deletions paddlemix/examples/janus/README.md
@@ -0,0 +1,67 @@
# Janus/JanusFlow

## 1. Model Introduction

[Janus/JanusFlow](https://github.com/deepseek-ai/Janus) decouples visual encoding into separate pathways while still processing everything with a single unified transformer architecture, addressing the limitations of earlier approaches. The decoupling not only alleviates the conflict between the visual encoder's roles in understanding and generation, but also makes the framework more flexible.


**Model weights supported in this repository:**

| Model |
|--------------------|
| deepseek-ai/Janus-1.3B |
| deepseek-ai/JanusFlow-1.3B |

Note: the names match the Hugging Face weights of the same name, but the weights themselves are Paddle-framework tensors. Calling `xxx.from_pretrained("deepseek-ai/Janus-1.3B")` automatically downloads the weight folder into the cache directory.
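
For example, a minimal loading sketch (the class name comes from the inference scripts below) that triggers the automatic download:

```python
from paddlemix.models.janus import JanusMultiModalityCausalLM

# The first call downloads the Paddle-format weights into the local cache.
model = JanusMultiModalityCausalLM.from_pretrained("deepseek-ai/Janus-1.3B")
```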


## 2. Environment Setup

1) [Install the PaddleMIX dependencies](https://github.com/PaddlePaddle/PaddleMIX/tree/b4f97ff859e1964c839fc5fab94f7ba63b1e5959?tab=readme-ov-file#%E5%AE%89%E8%A3%85)

2) pip install pillow tqdm paddlenlp==3.0.0b2
Note: Python 3.10 or later is recommended.

## 3. Quick Start

### Inference
```bash
# Janus understanding
python paddlemix/examples/janus/run_understanding_inference.py \
--model_path="deepseek-ai/Janus-1.3B" \
--image_file="paddlemix/demo_images/examples_image1.jpg" \
--question="What is shown in this image?"

# Janus generation
python paddlemix/examples/janus/run_generation_inference.py \
--model_path="deepseek-ai/Janus-1.3B" \
--prompt="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair"

# JanusFlow generation
python paddlemix/examples/janus/run_generation_inference_janusflow.py \
--model_path="deepseek-ai/JanusFlow-1.3B" \
--inference_step=30 \
--prompt="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair"

# Janus interactive chat
python paddlemix/examples/janus/run_interactivechat.py \
--model_path="deepseek-ai/Janus-1.3B"

```
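
For programmatic use, the scripts share the same setup. Below is a condensed sketch of the prompt construction from `run_generation_inference.py` (the sampling loop itself lives in that script; the prompt text here is just an example):

```python
from paddlenlp.transformers import LlamaTokenizerFast

from paddlemix.models.janus import JanusMultiModalityCausalLM
from paddlemix.processors import JanusImageProcessor, JanusVLChatProcessor

model_path = "deepseek-ai/Janus-1.3B"
vl_gpt = JanusMultiModalityCausalLM.from_pretrained(model_path)
tokenizer = LlamaTokenizerFast.from_pretrained(model_path)
image_processor = JanusImageProcessor.from_pretrained(model_path)
processor = JanusVLChatProcessor(image_processor, tokenizer)

conversation = [
    {"role": "User", "content": "A cute corgi wearing sunglasses"},
    {"role": "Assistant", "content": ""},
]
sft_format = processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation, sft_format=processor.sft_format, system_prompt=""
)
# Appending the image-start tag switches the model into image generation.
prompt = sft_format + processor.image_start_tag
```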

### References
```BibTeX
@article{wu2024janus,
title={Janus: Decoupling visual encoding for unified multimodal understanding and generation},
author={Wu, Chengyue and Chen, Xiaokang and Wu, Zhiyu and Ma, Yiyang and Liu, Xingchao and Pan, Zizheng and Liu, Wen and Xie, Zhenda and Yu, Xingkai and Ruan, Chong and others},
journal={arXiv preprint arXiv:2410.13848},
year={2024}
}
@article{ma2024janusflow,
title={JanusFlow: Harmonizing Autoregression and Rectified Flow for Unified Multimodal Understanding and Generation},
author={Yiyang Ma and Xingchao Liu and Xiaokang Chen and Wen Liu and Chengyue Wu and Zhiyu Wu and Zizheng Pan and Zhenda Xie and Haowei Zhang and Xingkai Yu and Liang Zhao and Yisong Wang and Jiaying Liu and Chong Ruan},
journal={arXiv preprint arXiv:2411.07975},
year={2024}
}
```
3 changes: 3 additions & 0 deletions paddlemix/examples/janus/requirement.txt
@@ -0,0 +1,3 @@
pillow
tqdm
paddlenlp==3.0.0b2
121 changes: 121 additions & 0 deletions paddlemix/examples/janus/run_generation_inference.py
@@ -0,0 +1,121 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import numpy as np
import paddle
import PIL.Image
from paddlenlp.transformers import LlamaTokenizerFast
from tqdm import tqdm

from paddlemix.models.janus import JanusMultiModalityCausalLM
from paddlemix.processors import JanusImageProcessor, JanusVLChatProcessor

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="deepseek-ai/Janus-1.3B")
parser.add_argument(
"--prompt",
type=str,
default="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
)
args = parser.parse_args()

vl_gpt = JanusMultiModalityCausalLM.from_pretrained(args.model_path)
tokenizer = LlamaTokenizerFast.from_pretrained(args.model_path)
image_processor = JanusImageProcessor.from_pretrained(args.model_path)
vl_chat_processor: JanusVLChatProcessor = JanusVLChatProcessor(image_processor, tokenizer)

conversation = [
{
"role": "User",
"content": args.prompt,
},
{"role": "Assistant", "content": ""},
]
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
conversations=conversation, sft_format=vl_chat_processor.sft_format, system_prompt=""
)
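# The image-start tag prompts the model to begin emitting image tokens.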
prompt = sft_format + vl_chat_processor.image_start_tag


@paddle.no_grad()
def generate(
mmgpt,
vl_chat_processor,
prompt: str,
temperature: float = 1,
parallel_size: int = 2,
cfg_weight: float = 5,
image_token_num_per_image: int = 576,
img_size: int = 384,
patch_size: int = 16,
):
input_ids = vl_chat_processor.tokenizer.encode(prompt)
input_ids = paddle.to_tensor(data=input_ids.input_ids, dtype="int64")
tokens = paddle.zeros(shape=(parallel_size * 2, len(input_ids)), dtype="int32")
for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = vl_chat_processor.pad_id
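# Odd rows keep only the first and last tokens and are padded in between,
# forming the unconditional prompt for classifier-free guidance.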
inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens) # [4, 50, 2048]
generated_tokens = paddle.zeros(shape=(parallel_size, image_token_num_per_image), dtype="int32")
for i in tqdm(range(image_token_num_per_image)):
batch_size, seq_length = inputs_embeds.shape[:2]

past_key_values_length = outputs.past_key_values[0][0].shape[1] if i != 0 else 0
position_ids = paddle.arange(past_key_values_length, seq_length + past_key_values_length).expand(
(batch_size, seq_length)
)

outputs = mmgpt.language_model.llama(
position_ids=position_ids,
inputs_embeds=inputs_embeds, # [4, 50, 2048] on the first step, then [4, 1, 2048]
use_cache=True,
past_key_values=outputs.past_key_values if i != 0 else None,
return_dict=True,
)

hidden_states = outputs.last_hidden_state
logits = mmgpt.gen_head(hidden_states[:, -1, :])
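# Even rows hold conditional logits, odd rows unconditional;
# combine them with classifier-free guidance below.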
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]

logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = paddle.nn.functional.softmax(x=logits / temperature, axis=-1)
next_token = paddle.multinomial(x=probs, num_samples=1)

generated_tokens[:, i] = next_token.squeeze(axis=-1)
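# Duplicate the sampled token so both the conditional and unconditional rows receive it at the next step.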
next_token = paddle.concat(x=[next_token.unsqueeze(axis=1), next_token.unsqueeze(axis=1)], axis=1).reshape(
[-1]
)
img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
inputs_embeds = img_embeds.unsqueeze(axis=1)

dec = mmgpt.gen_vision_model.decode_code(
generated_tokens.to(dtype="int32"), shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
)
dec = dec.to("float32").cpu().numpy().transpose(0, 2, 3, 1)
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
visual_img[:, :, :] = dec
os.makedirs("janus_generated_samples", exist_ok=True)
for i in range(parallel_size):
save_path = os.path.join("janus_generated_samples", "img_{}.jpg".format(i))
PIL.Image.fromarray(visual_img[i]).save(save_path)


generate(vl_gpt, vl_chat_processor, prompt)
128 changes: 128 additions & 0 deletions paddlemix/examples/janus/run_generation_inference_janusflow.py
@@ -0,0 +1,128 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import numpy as np
import paddle
import PIL.Image
from paddlenlp.transformers import LlamaTokenizerFast
from tqdm import tqdm

from paddlemix.models.janus import JanusFlowMultiModalityCausalLM
from paddlemix.processors import JanusImageProcessor, JanusVLChatProcessor
from ppdiffusers.models import AutoencoderKL

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="deepseek-ai/JanusFlow-1.3B")
parser.add_argument(
"--prompt",
type=str,
default="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
)
parser.add_argument("--inference_step", type=int, default=30)

args = parser.parse_args()

vl_gpt = JanusFlowMultiModalityCausalLM.from_pretrained(args.model_path)
tokenizer = LlamaTokenizerFast.from_pretrained(args.model_path)
image_processor = JanusImageProcessor.from_pretrained(args.model_path)
vl_chat_processor: JanusVLChatProcessor = JanusVLChatProcessor(image_processor, tokenizer)
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")

conversation = [
{
"role": "User",
"content": args.prompt,
},
{"role": "Assistant", "content": ""},
]
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
conversations=conversation, sft_format=vl_chat_processor.sft_format, system_prompt=""
)
prompt = sft_format + vl_chat_processor.image_start_tag


@paddle.no_grad()
def generate(
vl_gpt,
vl_chat_processor,
tokenizer,
prompt,
cfg_weight: float = 2.0,
num_inference_steps: int = 30,
batch_size: int = 1,
):
input_ids = tokenizer(prompt, return_tensors="pd")["input_ids"]
tokens = paddle.stack(x=[input_ids] * batch_size * 2)[:, 0, :]
tokens[batch_size:, 1:] = vl_chat_processor.pad_id
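# The second half of the batch has its prompt padded out, forming the
# unconditional branch for classifier-free guidance.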
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
inputs_embeds = inputs_embeds[:, :-1, :]
z = paddle.randn(shape=(batch_size, 4, 48, 48), dtype="bfloat16")
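# Start from Gaussian noise in the VAE latent space; 4 x 48 x 48 latents
# decode to a 384 x 384 image with the SDXL VAE (8x upsampling).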
dt = 1.0 / num_inference_steps
dt = paddle.zeros_like(x=z, dtype="bfloat16") + dt
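# Uniform Euler step size, broadcast to the latent shape, for integrating
# the rectified-flow ODE from t=0 to t=1.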
attention_mask = paddle.ones(shape=(2 * batch_size, tuple(inputs_embeds.shape)[1] + 577))
attention_mask[batch_size:, 1 : tuple(inputs_embeds.shape)[1]] = 0
attention_mask = attention_mask.astype(dtype="int32")
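# Sequence length is prompt length + 577: one time-step embedding plus 576 latent tokens (24 x 24).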

for step in tqdm(range(num_inference_steps)):
z_input = paddle.concat(x=[z, z], axis=0)
t = step / num_inference_steps * 1000.0
t = paddle.to_tensor(data=[t] * tuple(z_input.shape)[0]).to(dt.place)
z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
z_emb = z_emb.reshape([tuple(z_emb.shape)[0], tuple(z_emb.shape)[1], -1]).transpose(perm=[0, 2, 1])
z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
llm_emb = (
paddle.concat(x=[inputs_embeds, t_emb.unsqueeze(axis=1), z_emb], axis=1)
if step == 0
else paddle.concat(x=[t_emb.unsqueeze(axis=1), z_emb], axis=1)
)
bs, seq_len, dim = llm_emb.shape
past_seen_tokens = inputs_embeds.shape[1] if step != 0 else 0
position_ids = paddle.arange(past_seen_tokens, past_seen_tokens + seq_len, dtype=paddle.int64).reshape([1, -1])
outputs = vl_gpt.language_model.llama(
position_ids=position_ids,
inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values if step != 0 else None,
return_dict=True,
)
if step == 0:
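# Keep only the prompt portion of the KV cache; the time/latent tokens
# are re-encoded and re-fed on every step.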
past_key_values = []
for kv in outputs.past_key_values:
# [2, 607, 16, 128]
k, v = kv[0], kv[1]
past_key_values.append((k[:, : inputs_embeds.shape[1], :, :], v[:, : inputs_embeds.shape[1], :, :]))
past_key_values = tuple(past_key_values)
hidden_states = outputs.last_hidden_state
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
hidden_states = hidden_states.reshape([tuple(z_emb.shape)[0], 24, 24, 768]).transpose(perm=[0, 3, 1, 2])
v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
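# Split the predicted velocity into conditional and unconditional halves,
# apply classifier-free guidance, then take one Euler step.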
v_cond, v_uncond = paddle.chunk(x=v, chunks=2)
v = cfg_weight * v_cond - (cfg_weight - 1.0) * v_uncond
z = z + dt * v
decoded_image = vae.decode(z / vae.config.scaling_factor).sample
images = decoded_image.astype(dtype="float32").clip_(min=-1.0, max=1.0).transpose(perm=[0, 2, 3, 1]).cpu().numpy()
images = ((images + 1) / 2.0 * 255).astype(np.uint8)

os.makedirs("janusflow_generated_samples", exist_ok=True)
for i in range(batch_size):
save_path = os.path.join("janusflow_generated_samples", "img_{}.jpg".format(i))
PIL.Image.fromarray(images[i]).save(save_path)


generate(vl_gpt, vl_chat_processor, tokenizer, prompt, num_inference_steps=args.inference_step)