release training code #99

Merged
merged 2 commits on May 1, 2024
75 changes: 75 additions & 0 deletions configs/train/stage1.yaml
@@ -0,0 +1,75 @@
exp_name: 'stage1'
output_dir: './exp_output'
seed: 42
resume_from_checkpoint: ''

checkpointing_steps: 2000
save_model_epoch_interval: 20

data:
  train_bs: 4
  video_folder: ''  # Your data root folder
  guids:
    - 'depth'
    - 'normal'
    - 'semantic_map'
    - 'dwpose'
  image_size: 768
  bbox_crop: false
  bbox_resize_ratio: [0.9, 1.5]
  aug_type: "Resize"
  data_parts:
    - "all"
  sample_margin: 30

validation:
  validation_steps: 1000
  ref_images:
    - validation_data/ref_images/val-0.png
  guidance_folders:
    - validation_data/guid_sequences/0
  guidance_indexes: [0, 30, 60, 90, 120]

solver:
  gradient_accumulation_steps: 1
  mixed_precision: 'fp16'
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: False
  max_train_steps: 100000  # 50000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1.0e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: 'constant'

  # optimizer
  use_8bit_adam: False
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "scaled_linear"
  steps_offset: 1
  clip_sample: false

guidance_encoder_kwargs:
  guidance_embedding_channels: 320
  guidance_input_channels: 3
  block_out_channels: [16, 32, 96, 256]

base_model_path: 'pretrained_models/stable-diffusion-v1-5'
vae_model_path: 'pretrained_models/sd-vae-ft-mse'
image_encoder_path: 'pretrained_models/image_encoder'

weight_dtype: 'fp16' # [fp16, fp32]
uncond_ratio: 0.1
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
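The training entry point is not part of this diff, so as a reading aid here is a minimal sketch of how a config like configs/train/stage1.yaml is typically consumed: load it and build the diffusion noise scheduler from noise_scheduler_kwargs. The use of OmegaConf and of diffusers' DDIMScheduler is an assumption, not something this PR confirms.

```python
# Sketch only: assumes OmegaConf for config loading and diffusers for the scheduler.
from omegaconf import OmegaConf
from diffusers import DDIMScheduler

cfg = OmegaConf.load("configs/train/stage1.yaml")

# Build the training noise scheduler from the kwargs block in the YAML.
# If enable_zero_snr is honored, the trainer would typically also pass
# rescale_betas_zero_snr=True here.
sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs, resolve=True)
noise_scheduler = DDIMScheduler(**sched_kwargs)

print(cfg.exp_name, noise_scheduler.config.num_train_timesteps)
```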

103 changes: 103 additions & 0 deletions configs/train/stage2.yaml
@@ -0,0 +1,103 @@
exp_name: 'stage2'
output_dir: './exp_output'
seed: 42
resume_from_checkpoint: ''

stage1_ckpt_step: 'latest'
stage1_ckpt_dir: '' # stage1 checkpoint folder

checkpointing_steps: 2000
save_model_epoch_interval: 20

data:
  train_bs: 1
  video_folder: ''  # Your data root folder
  guids:
    - 'depth'
    - 'normal'
    - 'semantic_map'
    - 'dwpose'
  image_size: 512
  bbox_crop: false
  bbox_resize_ratio: [0.9, 1.5]
  aug_type: "Resize"
  data_parts:
    - "all"
  sample_frames: 24
  sample_rate: 4

validation:
  validation_steps: 1000
  clip_length: 24
  ref_images:
    - validation_data/ref_images/val-1.png
  guidance_folders:
    - validation_data/guid_sequences/0
  guidance_indexes: [0, 30, 60, 90, 120]

solver:
  gradient_accumulation_steps: 1
  mixed_precision: 'fp16'
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: True
  max_train_steps: 50000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: 'constant'

  # optimizer
  use_8bit_adam: True
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false

guidance_encoder_kwargs:
  guidance_embedding_channels: 320
  guidance_input_channels: 3
  block_out_channels: [16, 32, 96, 256]

unet_additional_kwargs:
  use_inflated_groupnorm: true
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  motion_module_resolutions:
    - 1
    - 2
    - 4
    - 8
  motion_module_mid_block: true
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
      - Temporal_Self
      - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1

base_model_path: 'pretrained_models/stable-diffusion-v1-5'
vae_model_path: 'pretrained_models/sd-vae-ft-mse'
image_encoder_path: 'pretrained_models/image_encoder'
mm_path: './pretrained_models/mm_sd_v15_v2.ckpt'

weight_dtype: 'fp16' # [fp16, fp32]
uncond_ratio: 0.1
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
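Both stage configs end with uncond_ratio, noise_offset, snr_gamma and enable_zero_snr, which point to classifier-free-guidance dropout, noise-offset augmentation and Min-SNR loss weighting during training. The snippet below is a generic sketch of the latter two in their standard (epsilon-prediction) form; the helper names are mine, and how this PR's trainer actually applies them is not visible in the diff.

```python
import torch

def apply_noise_offset(noise: torch.Tensor, noise_offset: float = 0.05) -> torch.Tensor:
    # Noise-offset trick: add a per-sample, per-channel constant shift so the model
    # can learn overall brightness. Shape-agnostic over the trailing dims, so it
    # covers image latents (B, C, H, W) and video latents (B, C, F, H, W).
    shift = torch.randn(noise.shape[0], noise.shape[1], *([1] * (noise.dim() - 2)),
                        device=noise.device, dtype=noise.dtype)
    return noise + noise_offset * shift

def min_snr_loss_weight(snr: torch.Tensor, snr_gamma: float = 5.0) -> torch.Tensor:
    # Min-SNR-gamma weighting (Hang et al., 2023), epsilon-prediction form:
    # clamp the per-timestep SNR at gamma so easy, low-noise timesteps
    # do not dominate the loss.
    return torch.clamp(snr, max=snr_gamma) / snr
```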
Empty file added datasets/__init__.py
Empty file.
76 changes: 76 additions & 0 deletions datasets/data_utils.py
@@ -0,0 +1,76 @@
import os
import json
import random
from typing import List
import csv
import glob
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor
from tqdm import tqdm


def process_bbox(bbox, H, W, scale=1.):
    # Expand a bbox (x_min, y_min, x_max, y_max) to a square whose side is
    # max(width, height) * scale, centered on the original box and clipped
    # to the image bounds (H, W).
    x_min, y_min, x_max, y_max = bbox
    width = x_max - x_min
    height = y_max - y_min

    side_length = max(width, height)

    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2

    scaled_side_length = side_length * scale
    scaled_xmin = center_x - scaled_side_length / 2
    scaled_xmax = center_x + scaled_side_length / 2
    scaled_ymin = center_y - scaled_side_length / 2
    scaled_ymax = center_y + scaled_side_length / 2

    scaled_xmin = int(max(0, scaled_xmin))
    scaled_xmax = int(min(W, scaled_xmax))
    scaled_ymin = int(max(0, scaled_ymin))
    scaled_ymax = int(min(H, scaled_ymax))

    return scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax

def crop_bbox(img, bbox, do_resize=False, size=512):
    # Crop an image (path or PIL.Image) to the given bbox, optionally resizing
    # so that the longer side equals `size` while keeping the aspect ratio.
    if isinstance(img, (Path, str)):
        img = Image.open(img)
    cropped_img = img.crop(bbox)
    if do_resize:
        cropped_W, cropped_H = cropped_img.size
        ratio = size / max(cropped_W, cropped_H)
        new_W = int(cropped_W * ratio)  # PIL's resize expects integer dimensions
        new_H = int(cropped_H * ratio)
        cropped_img = cropped_img.resize((new_W, new_H))

    return cropped_img

def mask_to_bbox(mask_path):
    # Tight bounding box (x_min, y_min, x_max, y_max) around the non-zero region
    # of the first channel of a mask image.
    mask = np.array(Image.open(mask_path))[..., 0]
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    return xmin, ymin, xmax, ymax

def mask_to_bkgd(img_path, mask_path):
    # Zero out every pixel outside the mask, i.e. keep the masked subject
    # on a black background.
    img = Image.open(img_path)
    img_array = np.array(img)

    mask = Image.open(mask_path).convert("RGB")
    mask_array = np.array(mask)

    img_array = np.where(mask_array > 0, img_array, 0)
    return Image.fromarray(img_array)
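
For orientation, a usage sketch chaining the helpers above: get a tight box from the subject mask, square and pad it with process_bbox, then crop and resize the frame. The file paths below are placeholders.

```python
from PIL import Image
from datasets.data_utils import mask_to_bbox, process_bbox, crop_bbox

img_path = "example/frame_0000.png"    # placeholder paths
mask_path = "example/mask_0000.png"

W, H = Image.open(img_path).size
bbox = mask_to_bbox(mask_path)                     # tight (x_min, y_min, x_max, y_max)
square_bbox = process_bbox(bbox, H, W, scale=1.2)  # squared and clipped to the image
crop = crop_bbox(img_path, square_bbox, do_resize=True, size=512)
crop.save("example/crop_0000.png")
```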
