Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#50 from LokeZhou/autolable
Browse files Browse the repository at this point in the history
autolable update blip2
  • Loading branch information
LokeZhou authored Aug 9, 2023
2 parents 6b33637 + 7943a9a commit 6a75d16
Showing 1 changed file with 94 additions and 77 deletions.
171 changes: 94 additions & 77 deletions applications/Automatic_label/automatic_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
import paddle.nn.functional as F
from PIL import Image, ImageDraw, ImageFont


from paddlevlp.processors.groundingdino_processing import GroudingDinoProcessor
from paddlevlp.models.groundingdino.modeling import GroundingDinoModel
from paddlevlp.models.sam.modeling import SamModel
from paddlevlp.processors.sam_processing import SamProcessor
from paddlenlp.transformers import AutoTokenizer
from paddlevlp.processors.blip_processing import BlipImageProcessor, BlipTextProcessor
from paddlevlp.models.blip2.modeling import Blip2ForConditionalGeneration
from paddlevlp.processors.blip_processing import Blip2Processor
import nltk
def show_mask(mask, ax, random_color=False):
    """Overlay a segmentation mask on a matplotlib axes.

    Args:
        mask: array whose last two dims are (H, W); nonzero pixels are painted.
        ax: matplotlib axes to draw on.
        random_color: if True use a random RGB color, else a fixed blue.
    """
    if random_color:
        # random RGB + fixed 0.6 alpha
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # fixed semi-transparent dodger blue (RGBA)
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # broadcast the mask against the RGBA color to get an (H, W, 4) image
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def show_box(box, ax, label):
    """Draw a labelled bounding box on a matplotlib axes.

    Args:
        box: (x0, y0, x1, y1) corner coordinates.
        ax: matplotlib axes to draw on.
        label: text placed at the box's top-left corner.
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    # transparent fill so the underlying image stays visible
    ax.add_patch(
        plt.Rectangle(
            (x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, label)


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for
    training and eval. Using `PdArgumentParser` we can turn this class into
    argparse arguments to be able to specify them on the command line.
    """

    # Path or URL of the image to be auto-labelled (required).
    input_image: str = field(metadata={"help": "The name of input image."})

    # Text prompt fed to BLIP-2 when captioning the image.
    prompt: str = field(
        default="describe the image",
        metadata={"help": "The prompt of the image to be generated."
                  })  # "Question: how many cats are there? Answer:"

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # BLIP-2 captioning checkpoint used to describe the input image.
    blip2_model_name_or_path: str = field(
        default="paddlemix/blip2-caption-opt2.7b",
        metadata={"help": "Path to pretrained model or model identifier"}, )
    # Language-model backbone whose tokenizer BLIP-2 reuses.
    text_model_name_or_path: str = field(
        default="facebook/opt-2.7b",
        metadata={"help": "The type of text model to use (OPT, T5)."}, )
    # Grounding DINO open-set detector checkpoint.
    dino_model_name_or_path: str = field(
        default="GroundingDino/groundingdino-swint-ogc",
        metadata={"help": "Path to pretrained model or model identifier"}, )
    # SAM segmentation checkpoint.
    sam_model_name_or_path: str = field(
        default="Sam/SamVitH-1024",
        metadata={"help": "Path to pretrained model or model identifier"}, )
    # Detections whose max logit falls below this are discarded.
    box_threshold: float = field(
        default=0.3, metadata={"help": "box threshold."}, )
    # Token logits above this threshold contribute to the predicted phrase.
    text_threshold: float = field(
        default=0.25, metadata={"help": "text threshold."}, )
    output_dir: str = field(
        default="automatic_label", metadata={"help": "output directory."}, )
    # When True, the mask/box overlay is rendered and saved to output_dir.
    visual: bool = field(
        default=True, metadata={"help": "save visual image."}, )


def generate_caption(raw_image, prompt, processor, blip2_model):
    """Generate a caption for an image with BLIP-2.

    Args:
        raw_image: PIL image to caption.
        prompt: text prompt guiding generation (e.g. "describe the image").
        processor: Blip2Processor that prepares model inputs and decodes ids.
        blip2_model: Blip2ForConditionalGeneration model.

    Returns:
        The decoded caption string, stripped of surrounding whitespace.
    """
    inputs = processor(
        images=raw_image,
        text=prompt,
        return_tensors="pd",
        return_attention_mask=True,
        mode="test", )
    generated_ids, scores = blip2_model.generate(**inputs)
    # batch of size 1: decode the first (only) sequence
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True)[0].strip()
    logger.info("Generate text: {}".format(generated_text))

    return generated_text


def generate_tags(caption):
    """Extract noun tags from a caption to use as a detection prompt.

    Args:
        caption: free-form caption text.

    Returns:
        Comma-separated string of lemmatized nouns found in the caption.
    """
    lemma = nltk.wordnet.WordNetLemmatizer()

    # NOTE(review): downloads NLTK data on every call — could be hoisted to
    # startup; kept here to preserve behavior.
    nltk.download(['punkt', 'averaged_perceptron_tagger', 'wordnet'])
    # keep only nouns (POS tags starting with 'N')
    tags_list = [
        word for (word, pos) in nltk.pos_tag(nltk.word_tokenize(caption))
        if pos[0] == 'N'
    ]
    tags_lemma = [lemma.lemmatize(w) for w in tags_list]
    tags = ', '.join(map(str, tags_lemma))

    return tags


def main():
parser = PdArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses()
url = (data_args.input_image)

logger.info("blip2_model: {}".format(model_args.blip2_model_name_or_path))
#bulid blip2 processor
blip2_processor = Blip2Processor.from_pretrained(
model_args.blip2_model_name_or_path
) # "Salesforce/blip2-opt-2.7b"
#bulid blip2 model
blip2_model = Blip2ForConditionalGeneration.from_pretrained(model_args.blip2_model_name_or_path)

blip2_tokenizer_class = AutoTokenizer.from_pretrained(
model_args.text_model_name_or_path, use_fast=False)
blip2_image_processor = BlipImageProcessor.from_pretrained(
os.path.join(model_args.blip2_model_name_or_path, "processor", "eval"))
blip2_text_processor_class = BlipTextProcessor.from_pretrained(
os.path.join(model_args.blip2_model_name_or_path, "processor", "eval"))
blip2_processor = Blip2Processor(blip2_image_processor,
blip2_text_processor_class,
blip2_tokenizer_class)

# #bulid blip2 model
blip2_model = Blip2ForConditionalGeneration.from_pretrained(
model_args.blip2_model_name_or_path)
paddle.device.cuda.empty_cache()
blip2_model.eval()
blip2_model.to("gpu")

logger.info("blip2_model build finish!")

logger.info("dino_model: {}".format(model_args.dino_model_name_or_path))
#bulid dino processor
dino_processor = GroudingDinoProcessor.from_pretrained(
model_args.dino_model_name_or_path
)
model_args.dino_model_name_or_path)
#bulid dino model
dino_model = GroundingDinoModel.from_pretrained(model_args.dino_model_name_or_path)
dino_model = GroundingDinoModel.from_pretrained(
model_args.dino_model_name_or_path)
dino_model.eval()
logger.info("dino_model build finish!")

#buidl sam processor
sam_processor = SamProcessor.from_pretrained(
model_args.sam_model_name_or_path
)
model_args.sam_model_name_or_path)
#bulid model
logger.info("SamModel: {}".format(model_args.sam_model_name_or_path))
sam_model = SamModel.from_pretrained(model_args.sam_model_name_or_path,input_type="boxs")
sam_model = SamModel.from_pretrained(
model_args.sam_model_name_or_path, input_type="boxs")
logger.info("SamModel build finish!")

#read image
if os.path.isfile(url):
#read image
image_pil = Image.open(data_args.input_image)
else:
image_pil = Image.open(requests.get(url, stream=True).raw)

caption = generate_caption(image_pil,processor=blip2_processor,blip2_model=blip2_model)
prompt = generate_tags(caption)
logger.info("prompt: {}".format(prompt))

caption = generate_caption(
image_pil,
prompt=data_args.prompt,
processor=blip2_processor,
blip2_model=blip2_model)

det_prompt = generate_tags(caption)
logger.info("det prompt: {}".format(det_prompt))

image_pil = image_pil.convert("RGB")

#preprocess image text_prompt
image_tensor,mask,tokenized_out = dino_processor(images=image_pil,text=prompt)
image_tensor, mask, tokenized_out = dino_processor(
images=image_pil, text=det_prompt)

with paddle.no_grad():
outputs = dino_model(image_tensor,mask, input_ids=tokenized_out['input_ids'],
attention_mask=tokenized_out['attention_mask'],text_self_attention_masks=tokenized_out['text_self_attention_masks'],
position_ids=tokenized_out['position_ids'])
outputs = dino_model(
image_tensor,
mask,
input_ids=tokenized_out['input_ids'],
attention_mask=tokenized_out['attention_mask'],
text_self_attention_masks=tokenized_out[
'text_self_attention_masks'],
position_ids=tokenized_out['position_ids'])

logits = F.sigmoid(outputs["pred_logits"])[0] # (nq, 256)
boxes = outputs["pred_boxes"][0] # (nq, 4)

# filter output
# filter output
logits_filt = logits.clone()
boxes_filt = boxes.clone()
filt_mask = logits_filt.max(axis=1) > model_args.box_threshold
logits_filt = logits_filt[filt_mask] # num_filt, 256
boxes_filt = boxes_filt[filt_mask] # num_filt, 4

# build pred
# build pred
pred_phrases = []
for logit, box in zip(logits_filt, boxes_filt):
pred_phrase = dino_processor.decode(logit > model_args.text_threshold)
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")


size = image_pil.size
pred_dict = {
"boxes": boxes_filt,
"size": [size[1], size[0]], # H,W
"labels": pred_phrases,
}
logger.info("dino output{}".format(pred_dict))
H,W = size[1], size[0]

H, W = size[1], size[0]
boxes = []
for box in zip(boxes_filt):
box = box[0] * paddle.to_tensor([W, H, W, H])
Expand All @@ -231,12 +246,13 @@ def main():
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
boxes.append([x0, y0, x1, y1])
boxes = np.array(boxes)
image_seg,prompt = sam_processor(image_pil,input_type="boxs",box=boxes,point_coords=None)
seg_masks = sam_model(img=image_seg,prompt=prompt)
image_seg, prompt = sam_processor(
image_pil, input_type="boxs", box=boxes, point_coords=None)
seg_masks = sam_model(img=image_seg, prompt=prompt)
seg_masks = sam_processor.postprocess_masks(seg_masks)

logger.info("Sam finish!")

if model_args.visual:
# make dir
os.makedirs(model_args.output_dir, exist_ok=True)
Expand All @@ -247,16 +263,17 @@ def main():
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box, label in zip(boxes, pred_phrases):
show_box(box, plt.gca(), label)

plt.title(caption)
plt.axis('off')
plt.savefig(
os.path.join(model_args.output_dir, 'mask_pred.jpg'),
bbox_inches="tight", dpi=300, pad_inches=0.0
)
os.path.join(model_args.output_dir, 'mask_pred.jpg'),
bbox_inches="tight",
dpi=300,
pad_inches=0.0)

logger.info("finish!")


if __name__ == "__main__":
    main()

0 comments on commit 6a75d16

Please sign in to comment.