update code for better quality #105

Merged — 5 commits, Dec 20, 2024. Changes shown from all commits.
5 changes: 3 additions & 2 deletions README.md
@@ -36,7 +36,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g.

 ## 🔥🔥 News
 
-- (🔥 New) \[2024/12/20\] 1.6B 2K resolution [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). 🚀 Get your 2K resolution images within 4 seconds! Find more samples on the [Sana page](https://nvlabs.github.io/Sana/).
+- (🔥 New) \[2024/12/20\] 1.6B 2K resolution [Sana models](asset/docs/model_zoo.md) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). 🚀 Get your 2K resolution images within 4 seconds! Find more samples on the [Sana page](https://nvlabs.github.io/Sana/).
 - (🔥 New) \[2024/12/18\] `diffusers` supports Sana-LoRA fine-tuning! Sana-LoRA's training and convergence speed is super fast. [\[Guidance\]](asset/docs/sana_lora_dreambooth.md) or [\[diffusers docs\]](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md).
 - (🔥 New) \[2024/12/13\] `diffusers` has Sana! [All Sana models in diffusers safetensors](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released, and the diffusers pipelines `SanaPipeline`, `SanaPAGPipeline`, and `DPMSolverMultistepScheduler` (with FlowMatching) are all supported now. We prepared a [Model Card](asset/docs/model_zoo.md) for you to choose from.
 - (🔥 New) \[2024/12/10\] 1.6B BF16 [Sana model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) is released for stable fine-tuning.
@@ -126,7 +126,8 @@ DEMO_PORT=15432 \
 python app/app_sana.py \
     --share \
     --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
-    --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+    --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
+    --image_size=1024
 ```
 
 ### 1. How to use `SanaPipeline` with `🧨diffusers`
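The body of the `SanaPipeline` how-to above is collapsed in this view. As a rough sketch of the `diffusers` usage the 2024/12/13 news item refers to — the checkpoint id and exact call signature here are assumptions patterned on the model names above, not taken from this diff:

```python
# Sketch only: assumes a diffusers release that ships SanaPipeline and a
# 1024px BF16 diffusers checkpoint named like the 2K ones listed above.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",  # assumed repo id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    prompt="a cyberpunk cat with a neon sign that says Sana",
    guidance_scale=4.5,      # matches the new demo default (app/app_sana.py below)
    num_inference_steps=20,  # matches the flow_dpm-solver default in the scripts
).images[0]
image.save("sana_1024px.png")
```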
4 changes: 2 additions & 2 deletions app/app_sana.py
@@ -408,14 +408,14 @@ def generate(
                     minimum=1,
                     maximum=10,
                     step=0.1,
-                    value=5.0,
+                    value=4.5,
                 )
                 flow_dpms_pag_guidance_scale = gr.Slider(
                     label="PAG Guidance scale",
                     minimum=1,
                     maximum=4,
                     step=0.5,
-                    value=2.0,
+                    value=1.0,
                 )
                 with gr.Row():
                     use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
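Note that the new defaults effectively start the demo without PAG: a PAG scale of 1.0 fails the `pag_scale > 1.0` test in `guidance_type_select` (next file), so, assuming the demo routes through that helper, generation falls back to plain classifier-free guidance until the slider is raised.

```python
# How the new slider defaults interact with guidance_type_select below:
guidance_scale = 4.5  # new flow_dpms_guidance_scale default (was 5.0)
pag_scale = 1.0       # new flow_dpms_pag_guidance_scale default (was 2.0)
# pag_scale > 1.0 is False -> guidance_type resolves to "classifier-free"
```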
51 changes: 14 additions & 37 deletions app/sana_pipeline.py
@@ -28,8 +28,8 @@
 from diffusion import DPMS, FlowEuler
 from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
-from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor
-from diffusion.utils.config import SanaConfig
+from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
+from diffusion.utils.config import SanaConfig, model_init_config
 from diffusion.utils.logger import get_root_logger
 
 # from diffusion.utils.misc import read_config
@@ -40,6 +40,8 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     guidance_type = default_guidance_type
     if not (pag_scale > 1.0 and attn_type == "linear"):
         guidance_type = "classifier-free"
+    elif pag_scale > 1.0 and attn_type == "linear":
+        guidance_type = "classifier-free_PAG"
     return guidance_type
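As merged, the two branches cover complementary conditions: whenever the `if` does not fire, the `elif` does, so `default_guidance_type` is now effectively unused. A self-contained check of the selection logic (the function is restated verbatim for illustration):

```python
# Restates guidance_type_select from app/sana_pipeline.py verbatim.
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
    guidance_type = default_guidance_type
    if not (pag_scale > 1.0 and attn_type == "linear"):
        guidance_type = "classifier-free"
    elif pag_scale > 1.0 and attn_type == "linear":
        guidance_type = "classifier-free_PAG"
    return guidance_type

# PAG is kept only when pag_scale > 1.0 AND the attention backend is linear.
assert guidance_type_select("classifier-free_PAG", 2.0, "linear") == "classifier-free_PAG"
assert guidance_type_select("classifier-free_PAG", 1.0, "linear") == "classifier-free"
assert guidance_type_select("classifier-free_PAG", 2.0, "flash") == "classifier-free"
```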


@@ -93,15 +95,9 @@ def __init__(
         self.flow_shift = config.scheduler.flow_shift
         guidance_type = "classifier-free_PAG"
 
-        if config.model.mixed_precision == "fp16":
-            weight_dtype = torch.float16
-        elif config.model.mixed_precision == "bf16":
-            weight_dtype = torch.bfloat16
-        elif config.model.mixed_precision == "fp32":
-            weight_dtype = torch.float32
-        else:
-            raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
+        weight_dtype = get_weight_dtype(config.model.mixed_precision)
         self.weight_dtype = weight_dtype
+        self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
 
         self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
         self.vis_sampler = self.config.scheduler.vis_sampler
@@ -126,7 +122,7 @@ def __init__(
         ]
 
     def build_vae(self, config):
-        vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.weight_dtype)
+        vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
         return vae
 
     def build_text_encoder(self, config):
@@ -135,31 +131,12 @@ def build_text_encoder(self, config):

     def build_sana_model(self, config):
         # model setting
-        pred_sigma = getattr(config.scheduler, "pred_sigma", True)
-        learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
-        model_kwargs = {
-            "input_size": self.latent_size,
-            "pe_interpolation": config.model.pe_interpolation,
-            "config": config,
-            "model_max_length": config.text_encoder.model_max_length,
-            "qk_norm": config.model.qk_norm,
-            "micro_condition": config.model.micro_condition,
-            "caption_channels": self.text_encoder.config.hidden_size,
-            "y_norm": config.text_encoder.y_norm,
-            "attn_type": config.model.attn_type,
-            "ffn_type": config.model.ffn_type,
-            "mlp_ratio": config.model.mlp_ratio,
-            "mlp_acts": list(config.model.mlp_acts),
-            "in_channels": config.vae.vae_latent_dim,
-            "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-            "use_pe": config.model.use_pe,
-            "pred_sigma": pred_sigma,
-            "learn_sigma": learn_sigma,
-            "use_fp32_attention": config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
-        }
-        model = build_model(config.model.model, **model_kwargs)
-        model = model.to(self.weight_dtype)
-
+        model_kwargs = model_init_config(config, latent_size=self.latent_size)
+        model = build_model(
+            config.model.model,
+            use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
+            **model_kwargs,
+        )
         self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
         self.logger.info(
             f"{model.__class__.__name__}:{config.model.model},"
@@ -310,7 +287,7 @@ def forward(
             flow_shift=self.flow_shift,
         )
 
-        sample = sample.to(self.weight_dtype)
+        sample = sample.to(self.vae_dtype)
         with torch.no_grad():
             sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
 
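Net effect of the pipeline changes: the transformer still runs in `config.model.mixed_precision`, while the VAE now encodes and decodes in its own `config.vae.weight_dtype` (bfloat16 by default, per the `AEConfig` change below), and latents are cast to the VAE dtype before decoding. A minimal usage sketch, assuming the constructor, `from_pretrained`, and call signature that `app/sana_pipeline.py` exposes (the prompt and output path are placeholders):

```python
# Hedged sketch of the in-repo pipeline; paths match the README snippet above.
import torch
from torchvision.utils import save_image

from app.sana_pipeline import SanaPipeline

sana = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
sana.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")

image = sana(
    prompt="a tiny astronaut hatching from an egg on the moon",
    height=1024,
    width=1024,
    guidance_scale=4.5,
    pag_guidance_scale=1.0,  # >1.0 with linear attention selects classifier-free_PAG
    num_inference_steps=20,
)
save_image(image, "sana_1024.png", nrow=1, normalize=True, value_range=(-1, 1))
```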
Training config YAML (filename not shown in this capture):
@@ -78,7 +78,7 @@ scheduler:
 train:
   num_workers: 10
   seed: 1
-  train_batch_size: 64
+  train_batch_size: 4
   num_epochs: 100
   gradient_accumulation_steps: 1
   grad_checkpointing: true
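The per-GPU batch drops from 64 to 4 with `gradient_accumulation_steps` left at 1, so the effective batch shrinks by 16x unless the launch setup compensates. A back-of-the-envelope check (the GPU count is a placeholder, not something this diff specifies):

```python
# Effective batch = per-GPU batch x gradient accumulation x number of GPUs.
train_batch_size = 4             # new value in this diff (was 64)
gradient_accumulation_steps = 1  # unchanged in this diff
num_gpus = 8                     # placeholder; depends on your launch command

effective_batch = train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch)  # 32 here, vs. 512 with the old per-GPU batch of 64
```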
11 changes: 11 additions & 0 deletions diffusion/model/utils.py
@@ -589,3 +589,14 @@ def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
     else:
         assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
         return kernel_size // 2
+
+
+def get_weight_dtype(mixed_precision):
+    if mixed_precision in ["fp16", "float16"]:
+        return torch.float16
+    elif mixed_precision in ["bf16", "bfloat16"]:
+        return torch.bfloat16
+    elif mixed_precision in ["fp32", "float32"]:
+        return torch.float32
+    else:
+        raise ValueError(f"weight precision {mixed_precision} is not defined")
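The helper folds the dtype selection that was previously copy-pasted into `app/sana_pipeline.py`, `scripts/inference.py`, and `scripts/inference_dpg.py` into one place, and also accepts the long spellings that `AEConfig.weight_dtype` uses. A quick usage check:

```python
import torch

from diffusion.model.utils import get_weight_dtype

assert get_weight_dtype("bf16") is torch.bfloat16
assert get_weight_dtype("bfloat16") is torch.bfloat16  # AEConfig default below
assert get_weight_dtype("fp16") is torch.float16
assert get_weight_dtype("float32") is torch.float32
# Any other string raises: get_weight_dtype("int8") -> ValueError
```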
28 changes: 28 additions & 0 deletions diffusion/utils/config.py
@@ -79,6 +79,7 @@ class ModelConfig(BaseConfig):
 class AEConfig(BaseConfig):
     vae_type: str = "dc-ae"
     vae_pretrained: str = "mit-han-lab/dc-ae-f32c32-sana-1.0"
+    weight_dtype: str = "bfloat16"
     scale_factor: float = 0.41407
     vae_latent_dim: int = 32
     vae_downsample_rate: int = 32
@@ -191,3 +192,30 @@ class SanaConfig(BaseConfig):
     tracker_project_name: str = "t2i-evit-baseline"
     name: str = "baseline"
     loss_report_name: str = "loss"
+
+
+def model_init_config(config: SanaConfig, latent_size: int = 32):
+
+    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
+    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
+    return {
+        "input_size": latent_size,
+        "pe_interpolation": config.model.pe_interpolation,
+        "config": config,
+        "model_max_length": config.text_encoder.model_max_length,
+        "qk_norm": config.model.qk_norm,
+        "micro_condition": config.model.micro_condition,
+        "caption_channels": config.text_encoder.caption_channels,
+        "y_norm": config.text_encoder.y_norm,
+        "attn_type": config.model.attn_type,
+        "ffn_type": config.model.ffn_type,
+        "mlp_ratio": config.model.mlp_ratio,
+        "mlp_acts": list(config.model.mlp_acts),
+        "in_channels": config.vae.vae_latent_dim,
+        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
+        "use_pe": config.model.use_pe,
+        "linear_head_dim": config.model.linear_head_dim,
+        "pred_sigma": pred_sigma,
+        "learn_sigma": learn_sigma,
+        "cross_norm": config.model.cross_norm,
+    }
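`model_init_config` replaces the near-identical `model_kwargs` dicts that each entry point used to build by hand, so a call site now reduces to the pattern below (mirroring the merged call sites; `config` is a loaded `SanaConfig`, and `latent_size` is typically `image_size // config.vae.vae_downsample_rate`):

```python
# Mirrors the merged call sites in scripts/inference.py and inference_dpg.py.
from diffusion.model.builder import build_model
from diffusion.utils.config import model_init_config

model_kwargs = model_init_config(config, latent_size=latent_size)
model = build_model(
    config.model.model,
    use_fp32_attention=config.model.get("fp32_attention", False),
    **model_kwargs,
)
```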
44 changes: 10 additions & 34 deletions scripts/inference.py
@@ -37,8 +37,8 @@
 from diffusion import DPMS, FlowEuler, SASolverSampler
 from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
-from diffusion.model.utils import prepare_prompt_ar
-from diffusion.utils.config import SanaConfig
+from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar
+from diffusion.utils.config import SanaConfig, model_init_config
 from diffusion.utils.logger import get_root_logger
 from tools.download import find_model
 
@@ -209,15 +209,14 @@ def visualize(config, args, model, items, bs, sample_steps, cfg_scale, pag_scale
     else:
         raise ValueError(f"{args.sampling_algo} is not defined")
 
-    samples = samples.to(weight_dtype)
+    samples = samples.to(vae_dtype)
    samples = vae_decode(config.vae.vae_type, vae, samples)
     torch.cuda.empty_cache()
 
     os.umask(0o000)
     for i, sample in enumerate(samples):
         save_file_name = f"{chunk[i]}.jpg" if dict_prompt else f"{prompts[i][:100]}.jpg"
         save_path = os.path.join(save_root, save_file_name)
-        # logger.info(f"Saving path: {save_path}")
         save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1))
 

@@ -287,17 +286,12 @@ class SanaInference(SanaConfig):
     args.interval_guidance = [max(0, args.interval_guidance[0]), min(1, args.interval_guidance[1])]
     sample_steps_dict = {"dpm-solver": 20, "sa-solver": 25, "flow_dpm-solver": 20, "flow_euler": 28}
     sample_steps = args.step if args.step != -1 else sample_steps_dict[args.sampling_algo]
-    if config.model.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif config.model.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-    elif config.model.mixed_precision == "fp32":
-        weight_dtype = torch.float32
-    else:
-        raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
+    weight_dtype = get_weight_dtype(config.model.mixed_precision)
     logger.info(f"Inference with {weight_dtype}, default guidance_type: {guidance_type}, flow_shift: {flow_shift}")
 
-    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(weight_dtype)
+    vae_dtype = get_weight_dtype(config.vae.weight_dtype)
+    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(vae_dtype)
     tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name, device=device)
 
     null_caption_token = tokenizer(
@@ -306,27 +300,7 @@ class SanaInference(SanaConfig):
     null_caption_embs = text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
 
     # model setting
-    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
-    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
-    model_kwargs = {
-        "pe_interpolation": config.model.pe_interpolation,
-        "config": config,
-        "model_max_length": config.text_encoder.model_max_length,
-        "qk_norm": config.model.qk_norm,
-        "micro_condition": config.model.micro_condition,
-        "caption_channels": text_encoder.config.hidden_size,
-        "y_norm": config.text_encoder.y_norm,
-        "attn_type": config.model.attn_type,
-        "ffn_type": config.model.ffn_type,
-        "mlp_ratio": config.model.mlp_ratio,
-        "mlp_acts": list(config.model.mlp_acts),
-        "in_channels": config.vae.vae_latent_dim,
-        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-        "use_pe": config.model.use_pe,
-        "linear_head_dim": config.model.linear_head_dim,
-        "pred_sigma": pred_sigma,
-        "learn_sigma": learn_sigma,
-    }
+    model_kwargs = model_init_config(config, latent_size=latent_size)
     model = build_model(
         config.model.model, use_fp32_attention=config.model.get("fp32_attention", False), **model_kwargs
     ).to(device)
@@ -418,6 +392,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
     os.makedirs(save_root, exist_ok=True)
     if args.if_save_dirname and args.gpu_id == 0:
+        os.makedirs(f"{work_dir}/metrics", exist_ok=True)
         # save at work_dir/metrics/tmp_xxx.txt for metrics testing
         with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
             print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
@@ -441,6 +416,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
     os.makedirs(save_root, exist_ok=True)
     if args.if_save_dirname and args.gpu_id == 0:
+        os.makedirs(f"{work_dir}/metrics", exist_ok=True)
         # save at work_dir/metrics/tmp_xxx.txt for metrics testing
         with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
             print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
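The added `os.makedirs` guards the `open()` right below it: `open(..., "w")` creates the file but not missing parent directories, so a fresh `work_dir` with no `metrics/` subfolder used to raise `FileNotFoundError`. A minimal sketch of the failure mode and the fix:

```python
import os
import tempfile
import time

work_dir = tempfile.mkdtemp()  # stands in for a fresh run directory
dataset = "custom"

# Without this line, the open() below fails with FileNotFoundError,
# because open() does not create intermediate directories.
os.makedirs(f"{work_dir}/metrics", exist_ok=True)

with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
    f.write("save_dir_placeholder\n")
```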
41 changes: 7 additions & 34 deletions scripts/inference_dpg.py
@@ -41,8 +41,8 @@
     get_chunks,
 )
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
-from diffusion.model.utils import prepare_prompt_ar
-from diffusion.utils.config import SanaConfig
+from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar
+from diffusion.utils.config import SanaConfig, model_init_config
 from diffusion.utils.logger import get_root_logger
 
 # from diffusion.utils.misc import read_config
@@ -195,7 +195,7 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0):
     else:
         raise ValueError(f"{args.sampling_algo} is not defined")
 
-    samples = samples.to(weight_dtype)
+    samples = samples.to(vae_dtype)
     samples = vae_decode(config.vae.vae_type, vae, samples)
     torch.cuda.empty_cache()
 
@@ -298,17 +298,11 @@ class SanaInference(SanaConfig):
     args.interval_guidance = [max(0, args.interval_guidance[0]), min(1, args.interval_guidance[1])]
     sample_steps_dict = {"dpm-solver": 20, "sa-solver": 25, "flow_dpm-solver": 20, "flow_euler": 28}
     sample_steps = args.step if args.step != -1 else sample_steps_dict[args.sampling_algo]
-    if config.model.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif config.model.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-    elif config.model.mixed_precision == "fp32":
-        weight_dtype = torch.float32
-    else:
-        raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
+    weight_dtype = get_weight_dtype(config.model.mixed_precision)
     logger.info(f"Inference with {weight_dtype}, default guidance_type: {guidance_type}, flow_shift: {flow_shift}")
 
-    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(weight_dtype)
+    vae_dtype = get_weight_dtype(config.vae.weight_dtype)
+    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(vae_dtype)
     tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name, device=device)
 
     null_caption_token = tokenizer(
@@ -317,28 +311,7 @@ class SanaInference(SanaConfig):
     null_caption_embs = text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
 
     # model setting
-    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
-    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
-    model_kwargs = {
-        "input_size": latent_size,
-        "pe_interpolation": config.model.pe_interpolation,
-        "config": config,
-        "model_max_length": config.text_encoder.model_max_length,
-        "qk_norm": config.model.qk_norm,
-        "micro_condition": config.model.micro_condition,
-        "caption_channels": text_encoder.config.hidden_size,
-        "y_norm": config.text_encoder.y_norm,
-        "attn_type": config.model.attn_type,
-        "ffn_type": config.model.ffn_type,
-        "mlp_ratio": config.model.mlp_ratio,
-        "mlp_acts": list(config.model.mlp_acts),
-        "in_channels": config.vae.vae_latent_dim,
-        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-        "use_pe": config.model.use_pe,
-        "linear_head_dim": config.model.linear_head_dim,
-        "pred_sigma": pred_sigma,
-        "learn_sigma": learn_sigma,
-    }
+    model_kwargs = model_init_config(config, latent_size=latent_size)
     model = build_model(
         config.model.model, use_fp32_attention=config.model.get("fp32_attention", False), **model_kwargs
     ).to(device)