some modifications
Coobiw committed Mar 14, 2024
1 parent 9a049e1 commit 8857382
Showing 3 changed files with 22 additions and 8 deletions.
11 changes: 9 additions & 2 deletions cli_demo.py
@@ -54,6 +54,9 @@ def _load_model_processor(args):
     model, vis_processors, _ = load_model_and_preprocess("minigpt4qwen", args.model_type)
     model.load_checkpoint(args.checkpoint_path)
 
+    model.llm_model.transformer.bfloat16()
+    model.llm_model.lm_head.bfloat16()
+
     generation_config = {
         "chat_format": "chatml",
         "eos_token_id": 151643,
@@ -233,8 +236,12 @@ def main():
         if '<ImageHere>' not in query:
             query = '<Img><ImageHere></Img> ' + query
         first = False
-        with torch.autocast(device_type="cpu",enabled=True,dtype=torch.bfloat16) if args.cpu_only else torch.cuda.amp.autocast(enabled=True,dtype=torch.bfloat16):
-            response, history = model.chat(query, history=history, image_tensor=image_tensor, generation_config=generation_config)
+        if args.cpu_only:
+            model.bfloat16()
+            response, history = model.chat(query, history=history, image_tensor=image_tensor.bfloat16(), generation_config=generation_config)
+        else:
+            with torch.cuda.amp.autocast(enabled=True,dtype=torch.bfloat16):
+                response, history = model.chat(query, history=history, image_tensor=image_tensor, generation_config=generation_config)
         _clear_screen()
         print(f"\nUser: {query}")
         print(f"\nQwen-Chat: {response}")
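The change above replaces a single autocast context (CPU autocast when `--cpu-only`, CUDA autocast otherwise) with an explicit branch: on CPU the whole model and the image tensor are cast to bfloat16 up front, while the GPU path keeps autocast. A minimal self-contained sketch of that pattern, with a hypothetical `TinyChatModel` standing in for the real minigpt4qwen model:

```python
import torch
import torch.nn as nn

class TinyChatModel(nn.Module):
    """Hypothetical stand-in; .chat() here just runs a linear layer."""
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(16, 16)

    def chat(self, image_tensor):
        return self.lm_head(image_tensor)

model = TinyChatModel()
image_tensor = torch.randn(1, 16)
cpu_only = not torch.cuda.is_available()

if cpu_only:
    # CPU path: cast weights *and* inputs to bfloat16 explicitly, instead of
    # relying on torch.autocast(device_type="cpu") as the old code did.
    model.bfloat16()
    out = model.chat(image_tensor.bfloat16())
else:
    # GPU path: keep weights as loaded and let autocast run ops in bfloat16.
    model.cuda()
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
        out = model.chat(image_tensor.cuda())

print(out.dtype)  # torch.bfloat16 on both paths
```

One plausible reason for the switch is that explicit casting makes the CPU path independent of CPU-autocast operator coverage, at the cost of converting the weights in place.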
6 changes: 3 additions & 3 deletions lavis/projects/pp_qwen14b/sft_pp.yaml
@@ -39,7 +39,7 @@ model:
   lora_dropout: 0.05
 
   # text length when training
-  max_txt_len: 256
+  max_txt_len: 512
 
   # enable autocast of vit
   enable_autocast: False
@@ -77,14 +77,14 @@ run:
   dist_url: "env://"
   distributed: True
 
-  max_epoch: 2
+  max_epoch: 1
   log_freq: 10
 
   lr_sched: "linear_warmup_cosine_lr_step-wise"
   warmup_lr: 0
   init_lr: 2e-5
   min_lr: 0
-  warmup_ratio: 0.3
+  warmup_ratio: 0.1
 
   deepspeed_config:
     # global batch = 128 = n_ranks * grad_acc_steps * micro_batch_size = (4//2) * 64 * 1
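The config change shortens training (max_epoch 2 → 1, warmup_ratio 0.3 → 0.1) while doubling max_txt_len to 512; the batch arithmetic in the comment checks out: (4//2) * 64 * 1 = 2 * 64 = 128. Below is a sketch of the assumed semantics of "linear_warmup_cosine_lr_step-wise" (linear ramp over the warmup fraction of steps, then cosine decay to min_lr); the actual lavis scheduler implementation may differ in details:

```python
import math

def lr_at_step(step, total_steps, init_lr=2e-5, min_lr=0.0,
               warmup_lr=0.0, warmup_ratio=0.1):
    """Assumed step-wise linear-warmup + cosine-decay schedule."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from warmup_lr up to init_lr.
        return warmup_lr + (init_lr - warmup_lr) * step / max(warmup_steps, 1)
    # Cosine decay from init_lr down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return min_lr + 0.5 * (init_lr - min_lr) * (1 + math.cos(math.pi * progress))

total = 1000
for s in (0, 50, 100, 500, 999):
    print(s, f"{lr_at_step(s, total):.2e}")
```

With warmup_ratio at 0.1 instead of 0.3, the schedule reaches the peak learning rate after 10% of the steps and spends the rest in cosine decay.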
13 changes: 10 additions & 3 deletions webui_demo.py
@@ -37,6 +37,9 @@ def _load_model_processor(args):
     model, vis_processors, _ = load_model_and_preprocess("minigpt4qwen", args.model_type)
     model.load_checkpoint(args.checkpoint_path)
 
+    model.llm_model.transformer.bfloat16()
+    model.llm_model.lm_head.bfloat16()
+
     generation_config = {
         "chat_format": "chatml",
         "eos_token_id": 151643,
@@ -148,8 +151,12 @@ def gradio_answer(chatbot, history, img_list, do_sample,num_beams, temperature,
     image_tensor = img_list[0]  # to support multiple images: torch.stack(img_list).to(self.device)
     generation_config = GenerationConfig.from_dict(generation_config)
     global args
-    with torch.autocast(device_type="cpu",enabled=True,dtype=torch.bfloat16) if args.cpu_only else torch.cuda.amp.autocast(enabled=True,dtype=torch.bfloat16):
-        response, history = model.chat(query=chatbot[-1][0], history=history, image_tensor=image_tensor, generation_config=generation_config,verbose=True)
+    if args.cpu_only:
+        model.bfloat16()
+        response, history = model.chat(query=chatbot[-1][0], history=history, image_tensor=image_tensor.bfloat16(), generation_config=generation_config,verbose=True)
+    else:
+        with torch.cuda.amp.autocast(enabled=True,dtype=torch.bfloat16):
+            response, history = model.chat(query=chatbot[-1][0], history=history, image_tensor=image_tensor.bfloat16(), generation_config=generation_config,verbose=True)
     chatbot[-1][1] = response
     return chatbot, history, img_list
 
@@ -230,4 +237,4 @@ def gradio_answer(chatbot, history, img_list, do_sample,num_beams, temperature,
     )
     clear.click(gradio_reset, [history, img_list], [chatbot, image, text_input, upload_button, history, img_list], queue=False)
 
-    demo.launch(share=True)
+    demo.launch(share=True,inbrowser=True)
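The last hunk only adds `inbrowser=True` to the Gradio launch call: `share=True` asks Gradio to create a temporary public URL, and `inbrowser=True` additionally opens the local app in a new browser tab. A minimal illustration of the same launch options on a toy interface (the `echo` app is hypothetical, not part of this repository):

```python
import gradio as gr

def echo(message):
    # Trivial handler just to have something to serve.
    return message

demo = gr.Interface(fn=echo, inputs="text", outputs="text")

# share=True -> temporary public URL; inbrowser=True -> auto-open a local tab.
demo.launch(share=True, inbrowser=True)
```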
