Skip to content

Commit

Permalink
Merge pull request #302 from Anhforth/add_adm18
Browse files Browse the repository at this point in the history
Add adm18
  • Loading branch information
ftgreat authored Apr 14, 2023
2 parents 0aae95c + 554a0bb commit 0834cbf
Show file tree
Hide file tree
Showing 28 changed files with 7,908 additions and 131 deletions.
393 changes: 282 additions & 111 deletions examples/AltCLIP-m18/README.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion examples/AltCLIP-m18/altclip_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)

model = loader.get_model()

tokenizer = loader.get_tokenizer()
transform = loader.get_transform()

Expand All @@ -19,7 +20,7 @@
tokenizer = loader.get_tokenizer()

def inference():
image = Image.open("./dog.jpeg")
image = Image.open("./examples/AltCLIP-m18//dog.jpeg")
image = transform(image)
image = torch.tensor(image["pixel_values"]).to(device)
tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
Expand Down
2 changes: 1 addition & 1 deletion examples/AltCLIP/altclip_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
tokenizer = loader.get_tokenizer()

def inference():
image = Image.open("./dog.jpeg")
image = Image.open("./examples/AltCLIP/dog.jpeg")
image = transform(image)
image = torch.tensor(image["pixel_values"]).to(device)
tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
Expand Down
24 changes: 24 additions & 0 deletions examples/AltDiffusion-m18/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor

# Initialize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loader = AutoLoader(task_name="text2img", #contrastive learning
model_name="AltDiffusion-m18",
model_dir="./checkpoints",
use_fp16=False)
model = loader.get_model()
model.eval()
model.to(device)
predictor = Predictor(model)
prompt = "Daenerys Targaryen as a mermeid with a piercing gaze wearing an enchanted bikini in an underwater magical forest, highly detailed face, realistic face, beautiful detailed eyes, fantasy art, in the style of artgerm, illustration, epic, fantasy, intricate, hyper detailed, artstation, concept art, smooth, sharp focus, ray tracing, vibrant, photorealistic"
negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, extra head, extra legs,fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
seed = 553124
predictor.predict_generate_images(
prompt=prompt,negative_prompt=negative_prompt,seed=seed
)
2 changes: 1 addition & 1 deletion examples/AltDiffusion/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
predictor = Predictor(model)
predictor.predict_generate_images(
"Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation"
)
)
9 changes: 5 additions & 4 deletions flagai/auto_model/auto_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __getattr__(self, name):
"cpm3_lm": ("flagai.model.cpm3_model", "CPM3"),
"cpm3_train": ("flagai.model.cpm3_train_model", "CPM3"),
"diffusion_text2img": ("flagai.model.mm.AltDiffusion", "LatentDiffusion"),
"diffusion_m18_text2img": ("flagai.model.mm.AltDiffusionM18", "LatentDiffusion"),
"altclip_txt_img_matching": ("flagai.model.mm.AltCLIP", "AltCLIP"),
"evaclip_txt_img_matching": ("flagai.model.mm.eva_clip_model", "EVA_CLIP"),
}
Expand Down Expand Up @@ -121,7 +122,9 @@ def __getattr__(self, name):
"altdiffusion":
["flagai.model.mm.diffusion", "LatentDiffusion", "diffusion", "mm","flagai.model.mm.AltCLIP", "AltCLIPProcess"],
"altdiffusion-m9":
["flagai.model.mm.diffusion", "LatentDiffusion", "diffusion", "mm","flagai.model.mm.AltCLIP", "AltCLIPProcess"],
["flagai.model.mm.diffusionM18", "LatentDiffusion", "diffusion", "mm","flagai.model.mm.AltCLIP", "AltCLIPProcess"],
"altdiffusion-m18":
["flagai.model.mm.AltdiffusionM18", "LatentDiffusion", "diffusion_m18", "mm","flagai.model.mm.AltCLIP", "AltCLIPProcess"],
"swinv1-base-patch4-window7-224":
["flagai.model.vision.swinv1", "SwinTransformer", "swinv1", "vision"],
"swinv2-base-patch4-window8-256":
Expand Down Expand Up @@ -200,7 +203,6 @@ def __init__(self,
f"For the model_name: {model_name}, these tasks are be supported: {tasks}"
)
return

download_path = os.path.join(model_dir, raw_model_name)
print("*" * 20, task_name, model_name)
model_name_ = self.is_exist_finetuned_model(raw_model_name, task_name)
Expand All @@ -213,7 +215,7 @@ def __init__(self,
**kwargs)
if kwargs.get("use_fp16", None):
self.model.half()

if model_type == "nlp":
if brief_model_name in ["galactica", ]:
self.tokenizer = getattr(LazyImport(MODEL_DICT[model_name][4]),
Expand Down Expand Up @@ -254,7 +256,6 @@ def is_exist_finetuned_model(self, raw_model_name, task_name):
return model_name_
else :
return raw_model_name

except:
print("Model hub is not reachable.")
return raw_model_name
Expand Down
8 changes: 4 additions & 4 deletions flagai/model/mm/AltCLIP.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def __init__(self,
if text_config_dict is None:
text_config_dict = {}
# when reload the config from local, we need name to select which class should be instanced.
self.text_config = STUDENT_CONFIG_DICT[
kwargs['text_config']['model_type']](**kwargs.pop('text_config'))
self.text_config = STUDENT_CONFIG_DICT[kwargs['text_config']['model_type']](**kwargs.pop('text_config'))
self.num_layers = num_layers
self.text_model_name = text_model_name
self.vision_model_name = vision_model_name
Expand All @@ -98,7 +97,6 @@ def __init__(self, config: AltCLIPConfig, clip_model=None):
raise ValueError(
"config.vision_config is expected to be of type CLIPVisionConfig but is of type"
f" {type(config.vision_config)}.")

text_config = config.text_config
vision_config = config.vision_config

Expand Down Expand Up @@ -436,7 +434,7 @@ def forward(
class AltCLIP(BaseModel):

def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
super().__init__(config=config, **kwargs)

@classmethod
def from_pretrain(cls,
Expand All @@ -448,4 +446,6 @@ def from_pretrain(cls,
super().download(download_path, model_name, only_download_config=only_download_config)
pretrained_model_name_or_path = os.path.join(download_path, model_name)
print(pretrained_model_name_or_path)
print("Downloading AltCLIP")

return CLIPHF.from_pretrained(pretrained_model_name_or_path)
2 changes: 1 addition & 1 deletion flagai/model/mm/AltDiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,4 +1932,4 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
]

return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
((mean1 - mean2)**2) * torch.exp(-logvar2))
((mean1 - mean2)**2) * torch.exp(-logvar2))
Loading

0 comments on commit 0834cbf

Please sign in to comment.