update code for better quality (#105)
* change bs for 2K

Signed-off-by: lawrence-cj <cjs1020440147@icloud.com>

* 1. code update for better quality;
2. fix app bug for CFG+PAG inference;
3. change default inference setting to CFG only;

Signed-off-by: lawrence-cj <cjs1020440147@icloud.com>

* update README.md;

Signed-off-by: lawrence-cj <cjs1020440147@icloud.com>

* change config name

Signed-off-by: lawrence-cj <cjs1020440147@icloud.com>

* update README.md;

Signed-off-by: lawrence-cj <cjs1020440147@icloud.com>

---------

Signed-off-by: lawrence-cj <cjs1020440147@icloud.com>
Co-authored-by: Enze Xie <johnny_ez@163.com>
lawrence-cj and xieenze authored Dec 20, 2024
1 parent 374447b commit 32c94fe
Showing 11 changed files with 104 additions and 215 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -36,7 +36,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.

## 🔥🔥 News

- - (🔥 New) \[2024/12/20\] 1.6B 2K resolution [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). 🚀 Get your 2K resolution images within 4 seconds! Find more samples in [Sana page](https://nvlabs.github.io/Sana/).
+ - (🔥 New) \[2024/12/20\] 1.6B 2K resolution [Sana models](asset/docs/model_zoo.md) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). 🚀 Get your 2K resolution images within 4 seconds! Find more samples in [Sana page](https://nvlabs.github.io/Sana/).
- (🔥 New) \[2024/12/18\] `diffusers` supports Sana-LoRA fine-tuning! Sana-LoRA training and convergence are super fast. [\[Guidance\]](asset/docs/sana_lora_dreambooth.md) or [\[diffusers docs\]](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md).
- (🔥 New) \[2024/12/13\] `diffusers` has Sana! [All Sana models in diffusers safetensors](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released, and the diffusers pipelines `SanaPipeline`, `SanaPAGPipeline`, and `DPMSolverMultistepScheduler` (with FlowMatching) are all supported now. We provide a [Model Card](asset/docs/model_zoo.md) to help you choose.
- (🔥 New) \[2024/12/10\] 1.6B BF16 [Sana model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) is released for stable fine-tuning.
@@ -126,7 +126,8 @@ DEMO_PORT=15432 \
python app/app_sana.py \
--share \
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
-    --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+    --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
+    --image_size=1024
```

### 1. How to use `SanaPipeline` with `🧨diffusers`
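The body of this section is collapsed in the diff view. For orientation, a minimal sketch of the diffusers usage it covers is below; the model ID is the 2K checkpoint named in the News entry above, and the call signature assumes the standard diffusers text-to-image API, so treat the details as illustrative rather than exact.

```python
# Minimal sketch (assumptions: standard diffusers text-to-image API;
# model ID taken from the News entry above).
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    prompt="a cyberpunk cat with a neon sign that says 'Sana'",
    height=2048,
    width=2048,
    guidance_scale=4.5,        # the new CFG-only default from this commit
    num_inference_steps=20,    # flow_dpm-solver default in scripts/inference.py
).images[0]
image.save("sana_2k.png")
```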
4 changes: 2 additions & 2 deletions app/app_sana.py
@@ -408,14 +408,14 @@ def generate(
minimum=1,
maximum=10,
step=0.1,
-    value=5.0,
+    value=4.5,
)
flow_dpms_pag_guidance_scale = gr.Slider(
label="PAG Guidance scale",
minimum=1,
maximum=4,
step=0.5,
-    value=2.0,
+    value=1.0,
)
with gr.Row():
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
51 changes: 14 additions & 37 deletions app/sana_pipeline.py
@@ -28,8 +28,8 @@
from diffusion import DPMS, FlowEuler
from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
- from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor
- from diffusion.utils.config import SanaConfig
+ from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
+ from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger

# from diffusion.utils.misc import read_config
@@ -40,6 +40,8 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
guidance_type = default_guidance_type
if not (pag_scale > 1.0 and attn_type == "linear"):
guidance_type = "classifier-free"
+ elif pag_scale > 1.0 and attn_type == "linear":
+     guidance_type = "classifier-free_PAG"
return guidance_type
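This is the CFG+PAG app fix from the commit message: previously, when `pag_scale` exceeded 1.0 on a linear-attention model, whatever `default_guidance_type` the caller passed was returned unchanged; the new `elif` pins that case to `"classifier-free_PAG"` explicitly. A self-contained illustration of the patched behavior (argument values are hypothetical):

```python
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
    # Patched logic, copied from the hunk above.
    guidance_type = default_guidance_type
    if not (pag_scale > 1.0 and attn_type == "linear"):
        guidance_type = "classifier-free"
    elif pag_scale > 1.0 and attn_type == "linear":
        guidance_type = "classifier-free_PAG"
    return guidance_type

# New app defaults (guidance_scale=4.5, pag_guidance_scale=1.0) -> CFG only:
assert guidance_type_select("classifier-free_PAG", 1.0, "linear") == "classifier-free"
# PAG engages only when pag_scale > 1.0 on a linear-attention model:
assert guidance_type_select("classifier-free_PAG", 2.0, "linear") == "classifier-free_PAG"
assert guidance_type_select("classifier-free_PAG", 2.0, "flash") == "classifier-free"
```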


@@ -93,15 +95,9 @@ def __init__(
self.flow_shift = config.scheduler.flow_shift
guidance_type = "classifier-free_PAG"

- if config.model.mixed_precision == "fp16":
-     weight_dtype = torch.float16
- elif config.model.mixed_precision == "bf16":
-     weight_dtype = torch.bfloat16
- elif config.model.mixed_precision == "fp32":
-     weight_dtype = torch.float32
- else:
-     raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
+ weight_dtype = get_weight_dtype(config.model.mixed_precision)
self.weight_dtype = weight_dtype
+ self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)

self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
self.vis_sampler = self.config.scheduler.vis_sampler
@@ -126,7 +122,7 @@ def __init__(
]

def build_vae(self, config):
- vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.weight_dtype)
+ vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
return vae

def build_text_encoder(self, config):
@@ -135,31 +131,12 @@ def build_text_encoder(self, config):

def build_sana_model(self, config):
# model setting
- pred_sigma = getattr(config.scheduler, "pred_sigma", True)
- learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
- model_kwargs = {
-     "input_size": self.latent_size,
-     "pe_interpolation": config.model.pe_interpolation,
-     "config": config,
-     "model_max_length": config.text_encoder.model_max_length,
-     "qk_norm": config.model.qk_norm,
-     "micro_condition": config.model.micro_condition,
-     "caption_channels": self.text_encoder.config.hidden_size,
-     "y_norm": config.text_encoder.y_norm,
-     "attn_type": config.model.attn_type,
-     "ffn_type": config.model.ffn_type,
-     "mlp_ratio": config.model.mlp_ratio,
-     "mlp_acts": list(config.model.mlp_acts),
-     "in_channels": config.vae.vae_latent_dim,
-     "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-     "use_pe": config.model.use_pe,
-     "pred_sigma": pred_sigma,
-     "learn_sigma": learn_sigma,
-     "use_fp32_attention": config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
- }
- model = build_model(config.model.model, **model_kwargs)
- model = model.to(self.weight_dtype)
-
+ model_kwargs = model_init_config(config, latent_size=self.latent_size)
+ model = build_model(
+     config.model.model,
+     use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
+     **model_kwargs,
+ )
self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
self.logger.info(
f"{model.__class__.__name__}:{config.model.model},"
@@ -310,7 +287,7 @@ def forward(
flow_shift=self.flow_shift,
)

- sample = sample.to(self.weight_dtype)
+ sample = sample.to(self.vae_dtype)
with torch.no_grad():
sample = vae_decode(self.config.vae.vae_type, self.vae, sample)

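Net effect for `app/sana_pipeline.py`: the transformer stays in `config.model.mixed_precision`, while the DC-AE VAE now loads and decodes in its own `config.vae.weight_dtype` (bfloat16 by default; see the `AEConfig` addition below). A rough end-to-end sketch of driving this class follows; the constructor, `from_pretrained`, and call signatures are assumptions inferred from this diff and the README demo command above, not an exact API reference.

```python
# Sketch only: argument names are assumptions based on this diff.
import torch
from app.sana_pipeline import SanaPipeline

pipe = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
pipe.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")

generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    height=1024,
    width=1024,
    guidance_scale=4.5,
    pag_guidance_scale=1.0,   # <= 1.0 resolves to plain classifier-free guidance
    num_inference_steps=20,
    generator=generator,
)
# Inside forward(), latents are cast to pipe.vae_dtype (bf16 by default)
# right before vae_decode, per the change above.
```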
@@ -78,7 +78,7 @@ scheduler:
train:
num_workers: 10
seed: 1
- train_batch_size: 64
+ train_batch_size: 4
num_epochs: 100
gradient_accumulation_steps: 1
grad_checkpointing: true
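This hunk (the file name is not shown in this diff, but per the commit message it is presumably the 2K training config) is the "change bs for 2K" item: at 2048 px the f32 DC-AE yields a 64×64 latent grid, four times the tokens of the 1024 px setting, so the per-GPU batch drops from 64 to 4; effective batch can be recovered via `gradient_accumulation_steps` or more GPUs. A back-of-the-envelope check (the GPU count is an assumption):

```python
# Illustrative arithmetic only; vae_downsample_rate comes from AEConfig below.
image_size = 2048
vae_downsample_rate = 32
latent_size = image_size // vae_downsample_rate    # 64 -> 64x64 latent grid
tokens_per_image = latent_size ** 2                # 4096 (vs 1024 at 1024px)

train_batch_size = 4             # new per-GPU value in this config
gradient_accumulation_steps = 1  # raise to recover effective batch size
num_gpus = 8                     # assumption; depends on the training cluster
effective_batch = train_batch_size * gradient_accumulation_steps * num_gpus
print(latent_size, tokens_per_image, effective_batch)  # 64 4096 32
```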
11 changes: 11 additions & 0 deletions diffusion/model/utils.py
@@ -589,3 +589,14 @@ def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
else:
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
return kernel_size // 2


+ def get_weight_dtype(mixed_precision):
+     if mixed_precision in ["fp16", "float16"]:
+         return torch.float16
+     elif mixed_precision in ["bf16", "bfloat16"]:
+         return torch.bfloat16
+     elif mixed_precision in ["fp32", "float32"]:
+         return torch.float32
+     else:
+         raise ValueError(f"weight precision {mixed_precision} is not defined")
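The helper replaces three copies of the same if/elif ladder (removed in `app/sana_pipeline.py` above and in both inference scripts below) and accepts short and long precision names. For example:

```python
import torch

from diffusion.model.utils import get_weight_dtype

assert get_weight_dtype("bf16") is torch.bfloat16
assert get_weight_dtype("bfloat16") is torch.bfloat16  # long spelling also works
assert get_weight_dtype("fp16") is torch.float16
assert get_weight_dtype("float32") is torch.float32
# Any other string raises ValueError, e.g. get_weight_dtype("int8")
```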
28 changes: 28 additions & 0 deletions diffusion/utils/config.py
@@ -79,6 +79,7 @@ class ModelConfig(BaseConfig):
class AEConfig(BaseConfig):
vae_type: str = "dc-ae"
vae_pretrained: str = "mit-han-lab/dc-ae-f32c32-sana-1.0"
+ weight_dtype: str = "bfloat16"
scale_factor: float = 0.41407
vae_latent_dim: int = 32
vae_downsample_rate: int = 32
@@ -191,3 +192,30 @@ class SanaConfig(BaseConfig):
tracker_project_name: str = "t2i-evit-baseline"
name: str = "baseline"
loss_report_name: str = "loss"


+ def model_init_config(config: SanaConfig, latent_size: int = 32):
+
+     pred_sigma = getattr(config.scheduler, "pred_sigma", True)
+     learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
+     return {
+         "input_size": latent_size,
+         "pe_interpolation": config.model.pe_interpolation,
+         "config": config,
+         "model_max_length": config.text_encoder.model_max_length,
+         "qk_norm": config.model.qk_norm,
+         "micro_condition": config.model.micro_condition,
+         "caption_channels": config.text_encoder.caption_channels,
+         "y_norm": config.text_encoder.y_norm,
+         "attn_type": config.model.attn_type,
+         "ffn_type": config.model.ffn_type,
+         "mlp_ratio": config.model.mlp_ratio,
+         "mlp_acts": list(config.model.mlp_acts),
+         "in_channels": config.vae.vae_latent_dim,
+         "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
+         "use_pe": config.model.use_pe,
+         "linear_head_dim": config.model.linear_head_dim,
+         "pred_sigma": pred_sigma,
+         "learn_sigma": learn_sigma,
+         "cross_norm": config.model.cross_norm,
+     }
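`model_init_config` gathers the transformer kwargs that `app/sana_pipeline.py`, `scripts/inference.py`, and `scripts/inference_dpg.py` each used to assemble by hand. Note two quiet behavior changes: `caption_channels` now comes from the static `config.text_encoder.caption_channels` rather than the loaded text encoder's `hidden_size`, and `linear_head_dim` and `cross_norm` are always passed. A sketch of the call pattern the scripts below adopt; the wrapper function and the latent-size derivation are assumptions based on the config fields in this diff:

```python
# Sketch of the new call pattern; build_sana is hypothetical glue code.
from diffusion.model.builder import build_model
from diffusion.utils.config import model_init_config


def build_sana(config, device="cuda"):
    # e.g. 1024 // 32 = 32 latent cells per side for the 1024px configs
    latent_size = config.model.image_size // config.vae.vae_downsample_rate
    model_kwargs = model_init_config(config, latent_size=latent_size)
    model = build_model(
        config.model.model,
        use_fp32_attention=config.model.get("fp32_attention", False),
        **model_kwargs,
    )
    return model.to(device)
```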
44 changes: 10 additions & 34 deletions scripts/inference.py
@@ -37,8 +37,8 @@
from diffusion import DPMS, FlowEuler, SASolverSampler
from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
- from diffusion.model.utils import prepare_prompt_ar
- from diffusion.utils.config import SanaConfig
+ from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar
+ from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger
from tools.download import find_model

@@ -209,15 +209,14 @@ def visualize(config, args, model, items, bs, sample_steps, cfg_scale, pag_scale
else:
raise ValueError(f"{args.sampling_algo} is not defined")

- samples = samples.to(weight_dtype)
+ samples = samples.to(vae_dtype)
samples = vae_decode(config.vae.vae_type, vae, samples)
torch.cuda.empty_cache()

os.umask(0o000)
for i, sample in enumerate(samples):
save_file_name = f"{chunk[i]}.jpg" if dict_prompt else f"{prompts[i][:100]}.jpg"
save_path = os.path.join(save_root, save_file_name)
- # logger.info(f"Saving path: {save_path}")
save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1))


@@ -287,17 +286,12 @@ class SanaInference(SanaConfig):
args.interval_guidance = [max(0, args.interval_guidance[0]), min(1, args.interval_guidance[1])]
sample_steps_dict = {"dpm-solver": 20, "sa-solver": 25, "flow_dpm-solver": 20, "flow_euler": 28}
sample_steps = args.step if args.step != -1 else sample_steps_dict[args.sampling_algo]
- if config.model.mixed_precision == "fp16":
-     weight_dtype = torch.float16
- elif config.model.mixed_precision == "bf16":
-     weight_dtype = torch.bfloat16
- elif config.model.mixed_precision == "fp32":
-     weight_dtype = torch.float32
- else:
-     raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")

+ weight_dtype = get_weight_dtype(config.model.mixed_precision)
logger.info(f"Inference with {weight_dtype}, default guidance_type: {guidance_type}, flow_shift: {flow_shift}")

- vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(weight_dtype)
+ vae_dtype = get_weight_dtype(config.vae.weight_dtype)
+ vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(vae_dtype)
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name, device=device)

null_caption_token = tokenizer(
@@ -306,27 +300,7 @@ class SanaInference(SanaConfig):
null_caption_embs = text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]

# model setting
- pred_sigma = getattr(config.scheduler, "pred_sigma", True)
- learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
- model_kwargs = {
-     "pe_interpolation": config.model.pe_interpolation,
-     "config": config,
-     "model_max_length": config.text_encoder.model_max_length,
-     "qk_norm": config.model.qk_norm,
-     "micro_condition": config.model.micro_condition,
-     "caption_channels": text_encoder.config.hidden_size,
-     "y_norm": config.text_encoder.y_norm,
-     "attn_type": config.model.attn_type,
-     "ffn_type": config.model.ffn_type,
-     "mlp_ratio": config.model.mlp_ratio,
-     "mlp_acts": list(config.model.mlp_acts),
-     "in_channels": config.vae.vae_latent_dim,
-     "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-     "use_pe": config.model.use_pe,
-     "linear_head_dim": config.model.linear_head_dim,
-     "pred_sigma": pred_sigma,
-     "learn_sigma": learn_sigma,
- }
+ model_kwargs = model_init_config(config, latent_size=latent_size)
model = build_model(
config.model.model, use_fp32_attention=config.model.get("fp32_attention", False), **model_kwargs
).to(device)
@@ -418,6 +392,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
os.makedirs(save_root, exist_ok=True)
if args.if_save_dirname and args.gpu_id == 0:
+ os.makedirs(f"{work_dir}/metrics", exist_ok=True)
# save at work_dir/metrics/tmp_xxx.txt for metrics testing
with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
@@ -441,6 +416,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
os.makedirs(save_root, exist_ok=True)
if args.if_save_dirname and args.gpu_id == 0:
+ os.makedirs(f"{work_dir}/metrics", exist_ok=True)
# save at work_dir/metrics/tmp_xxx.txt for metrics testing
with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
41 changes: 7 additions & 34 deletions scripts/inference_dpg.py
@@ -41,8 +41,8 @@
get_chunks,
)
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
- from diffusion.model.utils import prepare_prompt_ar
- from diffusion.utils.config import SanaConfig
+ from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar
+ from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger

# from diffusion.utils.misc import read_config
@@ -195,7 +195,7 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0):
else:
raise ValueError(f"{args.sampling_algo} is not defined")

- samples = samples.to(weight_dtype)
+ samples = samples.to(vae_dtype)
samples = vae_decode(config.vae.vae_type, vae, samples)
torch.cuda.empty_cache()

@@ -298,17 +298,11 @@ class SanaInference(SanaConfig):
args.interval_guidance = [max(0, args.interval_guidance[0]), min(1, args.interval_guidance[1])]
sample_steps_dict = {"dpm-solver": 20, "sa-solver": 25, "flow_dpm-solver": 20, "flow_euler": 28}
sample_steps = args.step if args.step != -1 else sample_steps_dict[args.sampling_algo]
- if config.model.mixed_precision == "fp16":
-     weight_dtype = torch.float16
- elif config.model.mixed_precision == "bf16":
-     weight_dtype = torch.bfloat16
- elif config.model.mixed_precision == "fp32":
-     weight_dtype = torch.float32
- else:
-     raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
+ weight_dtype = get_weight_dtype(config.model.mixed_precision)
logger.info(f"Inference with {weight_dtype}, default guidance_type: {guidance_type}, flow_shift: {flow_shift}")

- vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(weight_dtype)
+ vae_dtype = get_weight_dtype(config.vae.weight_dtype)
+ vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(vae_dtype)
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name, device=device)

null_caption_token = tokenizer(
@@ -317,28 +311,7 @@ class SanaInference(SanaConfig):
null_caption_embs = text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]

# model setting
- pred_sigma = getattr(config.scheduler, "pred_sigma", True)
- learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
- model_kwargs = {
-     "input_size": latent_size,
-     "pe_interpolation": config.model.pe_interpolation,
-     "config": config,
-     "model_max_length": config.text_encoder.model_max_length,
-     "qk_norm": config.model.qk_norm,
-     "micro_condition": config.model.micro_condition,
-     "caption_channels": text_encoder.config.hidden_size,
-     "y_norm": config.text_encoder.y_norm,
-     "attn_type": config.model.attn_type,
-     "ffn_type": config.model.ffn_type,
-     "mlp_ratio": config.model.mlp_ratio,
-     "mlp_acts": list(config.model.mlp_acts),
-     "in_channels": config.vae.vae_latent_dim,
-     "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-     "use_pe": config.model.use_pe,
-     "linear_head_dim": config.model.linear_head_dim,
-     "pred_sigma": pred_sigma,
-     "learn_sigma": learn_sigma,
- }
+ model_kwargs = model_init_config(config, latent_size=latent_size)
model = build_model(
config.model.model, use_fp32_attention=config.model.get("fp32_attention", False), **model_kwargs
).to(device)

