updates for next release #264

Merged
merged 5 commits on Jan 13, 2024
18 changes: 18 additions & 0 deletions documentation/DEEPSPEED.md
@@ -34,6 +34,24 @@ ZeRO stands for **Zero Redundancy Optimizer**. This technique reduces the memory

ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. For a deep dive into ZeRO, please see the original [paper](https://arxiv.org/abs/1910.02054v3) (1910.02054v3).

## Known issues

### LoRA support

Due to how DeepSpeed changes the model saving routines, training LoRA with DeepSpeed is not currently supported.

This may change in a future release.

### Enabling / disabling DeepSpeed on existing checkpoints

Currently in SimpleTuner, DeepSpeed cannot be **enabled** when resuming from a checkpoint that did **not** previously use DeepSpeed.

Conversely, DeepSpeed cannot be **disabled** when attempting to resume training from a checkpoint that was trained using DeepSpeed.

To work around this issue, export the training pipeline to a complete set of model weights before enabling or disabling DeepSpeed on an in-progress training session.

It's unlikely this will ever be supported, as DeepSpeed's optimiser is very different from any of the usual choices.
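
For illustration only, here is a minimal sketch of the workaround described above. The paths and the `--pretrained_model_name_or_path` / `--resume_from_checkpoint` flags are assumptions about the surrounding tooling, not an exact SimpleTuner recipe:

```python
# Hedged sketch: after exporting the in-progress run to a complete set of
# model weights, verify the export loads as a standalone pipeline, then start
# a *new* training run that uses it as the base model instead of resuming.
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained("/path/to/exported-pipeline")

# Point the next run's --pretrained_model_name_or_path at the export and omit
# --resume_from_checkpoint; DeepSpeed can then be toggled via `accelerate config`.
```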

## DeepSpeed Stages

DeepSpeed offers three levels of optimisation for training a model, with each successive stage adding more overhead.
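
As a hedged illustration of what stage selection looks like, here is a minimal Python sketch using Accelerate's `DeepSpeedPlugin`; SimpleTuner users would normally set the equivalent options through `accelerate config`, and the values below are placeholders:

```python
# Minimal sketch, assuming Accelerate's DeepSpeedPlugin API.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,                    # 1, 2 or 3; higher stages shard more state
    gradient_accumulation_steps=4,
    offload_optimizer_device="cpu",  # optional offload, trades speed for memory
)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
```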
2 changes: 1 addition & 1 deletion helpers/arguments.py
@@ -979,7 +979,7 @@ def parse_args(input_args=None):
Path(args.cache_dir_text),
]:
os.makedirs(target_dir, exist_ok=True)
logger.info(f"VAE Cache location: {args.cache_dir_vae}")
logger.info(f"Default VAE Cache location: {args.cache_dir_vae}")
logger.info(f"Text Cache location: {args.cache_dir_text}")

if args.validation_resolution < 128:
7 changes: 5 additions & 2 deletions helpers/sdxl/save_hooks.py
@@ -25,13 +25,15 @@ def __init__(
text_encoder_1,
text_encoder_2,
accelerator,
use_deepspeed_optimizer,
):
self.args = args
self.unet = unet
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
self.ema_unet = ema_unet
self.accelerator = accelerator
self.use_deepspeed_optimizer = use_deepspeed_optimizer

def save_model_hook(self, models, weights, output_dir):
# Write "training_state.json" to the output directory containing the training state
@@ -66,11 +68,12 @@ def save_model_hook(self, models, weights, output_dir):
get_peft_model_state_dict(model)
)
)
else:
elif not self.use_deepspeed_optimizer:
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
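# With DeepSpeed managing checkpointing, the `weights` list handed to this hook
# can be empty, so only pop when there is something to pop.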
if weights:
weights.pop()

StableDiffusionXLPipeline.save_lora_weights(
output_dir,
47 changes: 36 additions & 11 deletions toolkit/captioning/caption_with_llava.py
@@ -98,13 +98,21 @@ def parse_args():


# Function to load LLaVA model
def load_llava_model(model_path: str = "llava-hf/llava-1.5-7b-hf"):
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
def load_llava_model(
model_path: str = "llava-hf/llava-1.5-7b-hf", precision: str = "fp4"
):
bnb_config = BitsAndBytesConfig()
if precision == "fp4":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
elif precision == "fp8":
bnb_config = BitsAndBytesConfig(
load_in_8bit=True,
)
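# Note: the "fp8" option maps to bitsandbytes 8-bit loading (load_in_8bit),
# while the default "fp4" path uses 4-bit NF4 quantization with bfloat16 compute.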
model = LlavaForConditionalGeneration.from_pretrained(
model_path, quantization_config=bnb_config, device_map="auto"
)
@@ -160,7 +168,7 @@ def resize_for_condition_image(input_image: Image, resolution: int):
img = input_image.resize((W, H), resample=Image.LANCZOS)
return img

return eval_model(args, resize_for_condition_image(image, 384), model, processor)
return eval_model(args, resize_for_condition_image(image, 256), model, processor)


# Function to convert content to filename
@@ -206,7 +214,7 @@ def process_directory(args, image_dir, output_dir, progress_file, model, process

for filename in tqdm(os.listdir(image_dir), desc="Processing Images"):
full_filepath = os.path.join(image_dir, filename)
if filename in processed_files[image_dir]:
if image_dir in processed_files and filename in processed_files[image_dir]:
logging.info(f"File has already been processed: {filename}")
continue

@@ -219,7 +227,7 @@ def process_directory(args, image_dir, output_dir, progress_file, model, process
try:
logging.info(f"Attempting to load image: {filename}")
with Image.open(full_filepath) as image:
logging.info(f"Processing image: {filename}")
logging.info(f"Processing image: {filename}, data: {image}")
best_match = process_and_evaluate_image(
args, full_filepath, model, processor
)
@@ -256,6 +264,23 @@ def process_directory(args, image_dir, output_dir, progress_file, model, process

except Exception as e:
logging.error(f"Error processing {filename}: {str(e)}")
if "CUDA error" in str(e):
import sys

sys.exit(1)
if "name too long" in str(e):
# Loop and try to reduce the filename length until it works:
exception_error = str(e)
while "name too long" in exception_error:
# Cut the filename down by one character:
new_filename = new_filename[:-1]
try:
new_filepath = os.path.join(output_dir, new_filename)
# Try to save again
image.save(new_filepath)
exception_error = ""
except Exception as e:
exception_error = str(e)


def main():
@@ -268,7 +293,7 @@ def main():
os.makedirs(args.output_dir)

# Load model
model, processor = load_llava_model(args.model_path)
model, processor = load_llava_model(args.model_path, args.precision)

# Process directory
process_directory(
7 changes: 5 additions & 2 deletions toolkit/datasets/csv_to_s3.py
@@ -587,8 +587,11 @@ def process_and_upload(image_path_args):
# Resize image for condition if required
image = resize_for_condition_image(image, args.condition_image_size)
temp_path = os.path.join(args.temporary_folder, os.path.basename(image_path))
image.save(temp_path, format="PNG")
image.close()
try:
image.save(temp_path, format="PNG")
image.close()
except Exception as e:
logger.error(f"Error saving image {temp_path}: {e}")
# Upload to S3
upload_local_image_to_s3(temp_path, args, s3_client)

69 changes: 49 additions & 20 deletions toolkit/datasets/dataset_from_kellyc.py
@@ -21,27 +21,50 @@ def get_photo_id(url):
return match.group(1) if match else None


def download_smallest_image(urls, output_path, minimum_image_size: int):
"""Download the smallest image based on width."""
smallest_url = min(urls, key=get_image_width)
response = requests.get(smallest_url, stream=True)
conn_timeout = 6
read_timeout = 60
timeouts = (conn_timeout, read_timeout)
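# requests accepts a (connect, read) timeout tuple; these values are passed to
# requests.get() in download_image() below.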


def download_image(url, output_path, minimum_image_size: int, minimum_pixel_area: int):
"""Download an image."""
response = requests.get(url, timeout=timeouts, stream=True)

if response.status_code == 200:
filename = os.path.basename(smallest_url.split("?")[0])
filename = os.path.basename(url.split("?")[0])
file_path = os.path.join(output_path, filename)
# Convert path to PNG:
file_path = file_path.replace(".jpg", ".png")

with open(file_path, "wb") as f:
for chunk in response.iter_content(1024):
f.write(chunk)
# Is the file >= 1024px on both sides?
# Check if the file meets the minimum size requirements
image = Image.open(file_path)
width, height = image.size
if width < minimum_image_size or height < minimum_image_size:
if minimum_image_size > 0 and (
width < minimum_image_size or height < minimum_image_size
):
os.remove(file_path)
return f"Nuked tiny image: {smallest_url}"
return f"Nuked tiny image: {url}"
if minimum_pixel_area > 0 and (width * height < minimum_pixel_area):
os.remove(file_path)
return f"Nuked tiny image: {url}"

return f"Downloaded: {url}"
return f"Failed to download: {url}"

return f"Downloaded: {smallest_url}"
return f"Failed to download: {smallest_url}"

def process_urls(urls, output_path, minimum_image_size: int, minimum_pixel_area: int):
"""Process a list of URLs."""
# Simple URL list
results = []
for url in urls:
result = download_image(
url, output_path, minimum_image_size, minimum_pixel_area
)
results.append(result)
return "\n".join(results)
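# Hypothetical usage sketch (URL, path and thresholds are illustrative only):
# process_urls(
#     ["https://example.com/photo.jpg"], "output/",
#     minimum_image_size=1024, minimum_pixel_area=1024 * 1024,
# )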


def main(args):
@@ -52,18 +75,17 @@ def main(args):
with open(args.file_path, "r") as file:
for line in file:
urls = line.strip().split()
for url in urls:
photo_id = get_photo_id(url)
if photo_id:
if photo_id not in url_groups:
url_groups[photo_id] = []
url_groups[photo_id].append(url)

# Using ThreadPoolExecutor to parallelize downloads
# Treat as a simple URL list
url_groups[line] = urls

with ThreadPoolExecutor(max_workers=args.workers) as executor:
futures = [
executor.submit(
download_smallest_image, urls, args.output_path, args.minimum_image_size
process_urls,
urls,
args.output_path,
args.minimum_image_size,
args.minimum_pixel_area,
)
for urls in url_groups.values()
]
@@ -89,7 +111,14 @@ def main(args):
parser.add_argument(
"--minimum_image_size",
type=int,
help="Both sides of the image must be larger than this.",
default=0,
help="Both sides of the image must be larger than this. ZERO disables this.",
)
parser.add_argument(
"--minimum_pixel_area",
type=int,
default=0,
help="The total number of pixels in the image must be larger than this. ZERO disables this. Recommended value: 1024*1024",
)
parser.add_argument(
"--workers",
50 changes: 50 additions & 0 deletions toolkit/inference/tile_samplers.py
@@ -0,0 +1,50 @@
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO

# Placeholder URL
url = "https://sa1s3optim.patientpop.com/assets/images/provider/photos/2353184.jpg"

# Download the image from the URL
response = requests.get(url)
original_image = Image.open(BytesIO(response.content))

# Define target size (1 megapixel)
target_width = 1000
target_height = 1000

# Resize the image using different samplers
samplers = {
"NEAREST": Image.NEAREST,
"BOX": Image.BOX,
"HAMMING": Image.HAMMING,
"BILINEAR": Image.BILINEAR,
"BICUBIC": Image.BICUBIC,
"LANCZOS": Image.LANCZOS,
}

# Create a new image to combine the results
combined_width = target_width * len(samplers)
combined_height = target_height + 50 # Extra space for labels
combined_image = Image.new("RGB", (combined_width, combined_height), "white")
draw = ImageDraw.Draw(combined_image)

# Load a default font
try:
font = ImageFont.load_default()
except IOError:
font = None

# Resize and add each sampler result to the combined image
for i, (label, sampler) in enumerate(samplers.items()):
resized_image = original_image.resize((target_width, target_height), sampler)
combined_image.paste(resized_image, (i * target_width, 50))

# Draw the label
text_position = (i * target_width + 20, 15)
draw.text(text_position, label, fill="black", font=font)

# Save or display the combined image
combined_image_path = "downsampled_image_comparison.png"
combined_image.save(combined_image_path)
print(f"Saved sampler comparison to {combined_image_path}")
9 changes: 7 additions & 2 deletions train_sdxl.py
@@ -15,7 +15,7 @@
# limitations under the License.
from helpers import log_format

import shutil, hashlib, json, copy, random, logging, math, os
import shutil, hashlib, json, copy, random, logging, math, os, sys

# Quiet down, you.
os.environ["ACCELERATE_LOG_LEVEL"] = "WARNING"
@@ -197,6 +197,11 @@ def main():
hasattr(accelerator.state, "deepspeed_plugin")
and accelerator.state.deepspeed_plugin is not None
):
if args.model_type == "lora":
logger.error(
"LoRA can not be trained with DeepSpeed. Please disable DeepSpeed via 'accelerate config' before reattempting."
)
sys.exit(1)
if (
"gradient_accumulation_steps"
in accelerator.state.deepspeed_plugin.deepspeed_config
@@ -416,7 +421,6 @@ def main():
configure_multi_databackend(args, accelerator)
except Exception as e:
logging.error(f"{e}")
import sys

sys.exit(0)

@@ -752,6 +756,7 @@ def main():
accelerator=accelerator,
text_encoder_1=text_encoder_1,
text_encoder_2=text_encoder_2,
use_deepspeed_optimizer=use_deepspeed_optimizer,
)
accelerator.register_save_state_pre_hook(model_hooks.save_model_hook)
accelerator.register_load_state_pre_hook(model_hooks.load_model_hook)