Skip to content

Commit

Permalink
Turn off resizing images with --resize=False (#71)
Browse files Browse the repository at this point in the history
* Make image resize optional with --resize

Toggle off image resizing using --resize=False. The default is True, to maintain consistent operation.

* Make image resize optional with --resize

Toggle off image resizing using --resize=False. The default is True, to maintain consistent operation.

* Make image resize optional with --resize

Toggle off image resizing using --resize=False. The default is True, to maintain consistent operation.
  • Loading branch information
hdon96 authored Dec 24, 2022
1 parent 4869fe3 commit 39affb7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
15 changes: 14 additions & 1 deletion train_lora_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ def __init__(
size=512,
center_crop=False,
color_jitter=False,
resize=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.resize = resize

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
Expand All @@ -90,7 +92,9 @@ def __init__(
[
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
),
)
if resize
else transforms.Lambda(lambda x: x),
transforms.CenterCrop(size)
if center_crop
else transforms.RandomCrop(size),
Expand Down Expand Up @@ -419,6 +423,13 @@ def parse_args(input_args=None):
default=None,
help=("File path for text encoder lora to resume training."),
)
# NOTE: argparse's ``type=bool`` calls ``bool()`` on the raw command-line
# string, and every non-empty string (including "False") is truthy, so
# ``--resize=False`` would silently keep resizing enabled. Parse the text
# explicitly instead: the listed spellings (case-insensitive) disable
# resizing, anything else enables it. The empty string stays falsy, matching
# what ``bool("")`` returned before.
parser.add_argument(
    "--resize",
    type=lambda value: str(value).strip().lower()
    not in ("", "false", "f", "0", "no", "n", "off"),
    default=True,
    required=False,
    help="Should images be resized to --resolution before training?",
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -648,6 +659,8 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
color_jitter=args.color_jitter,
resize=args.resize,

)

def collate_fn(examples):
Expand Down
14 changes: 13 additions & 1 deletion train_lora_pt_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ def __init__(
size=512,
center_crop=False,
color_jitter=False,
resize=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.resize = resize

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
Expand Down Expand Up @@ -109,7 +111,9 @@ def __init__(
[
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
),
)
if resize
else transforms.Lambda(lambda x: x),
transforms.CenterCrop(size)
if center_crop
else transforms.RandomCrop(size),
Expand Down Expand Up @@ -482,6 +486,13 @@ def parse_args(input_args=None):
action="store_true",
help="Debug to see just ti",
)
# NOTE: argparse's ``type=bool`` calls ``bool()`` on the raw command-line
# string, and every non-empty string (including "False") is truthy, so
# ``--resize=False`` would silently keep resizing enabled. Parse the text
# explicitly instead: the listed spellings (case-insensitive) disable
# resizing, anything else enables it. The empty string stays falsy, matching
# what ``bool("")`` returned before.
parser.add_argument(
    "--resize",
    type=lambda value: str(value).strip().lower()
    not in ("", "false", "f", "0", "no", "n", "off"),
    default=True,
    required=False,
    help="Should images be resized to --resolution before training?",
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -749,6 +760,7 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
color_jitter=args.color_jitter,
resize=args.resize,
)

def collate_fn(examples):
Expand Down
16 changes: 14 additions & 2 deletions train_lora_w_ti.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,13 @@ def __init__(
size=512,
center_crop=False,
color_jitter=False,
resize=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.resize = resize


self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
Expand Down Expand Up @@ -168,7 +171,9 @@ def __init__(
[
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
),
)
if resize
else transforms.Lambda(lambda x: x),
transforms.CenterCrop(size)
if center_crop
else transforms.RandomCrop(size),
Expand Down Expand Up @@ -545,6 +550,13 @@ def parse_args(input_args=None):
action="store_true",
help="Debug to see just ti",
)
# NOTE: argparse's ``type=bool`` calls ``bool()`` on the raw command-line
# string, and every non-empty string (including "False") is truthy, so
# ``--resize=False`` would silently keep resizing enabled. Parse the text
# explicitly instead: the listed spellings (case-insensitive) disable
# resizing, anything else enables it. The empty string stays falsy, matching
# what ``bool("")`` returned before.
parser.add_argument(
    "--resize",
    type=lambda value: str(value).strip().lower()
    not in ("", "false", "f", "0", "no", "n", "off"),
    default=True,
    required=False,
    help="Should images be resized to --resolution before training?",
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -812,6 +824,7 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
color_jitter=args.color_jitter,
resize=args.resize,
)

def collate_fn(examples):
Expand Down Expand Up @@ -1104,7 +1117,6 @@ def collate_fn(examples):

if global_step >= args.max_train_steps:
break

accelerator.wait_for_everyone()

# Create the pipeline using the trained modules and save it.
Expand Down

0 comments on commit 39affb7

Please sign in to comment.