From da16e670e85a72527427df4c6f30b2108f241327 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sat, 5 Oct 2024 20:24:38 +0800 Subject: [PATCH] Fix minor bugs. --- keras_hub/src/models/image_to_image.py | 5 +++++ keras_hub/src/models/stable_diffusion_3/mmdit.py | 12 ++++++++++++ .../stable_diffusion_3_image_to_image.py | 4 ++-- .../stable_diffusion_3_image_to_image_test.py | 2 +- 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py index 2139b1af5..99dda6993 100644 --- a/keras_hub/src/models/image_to_image.py +++ b/keras_hub/src/models/image_to_image.py @@ -268,6 +268,11 @@ def generate( ): """Generate image based on the provided `images` and `inputs`. + The `images` are reference images that will be resized to + `self.backbone.height` and `self.backbone.width`, then encoded into + latent space by the VAE encoder. The `inputs` are strings that will be + tokenized and encoded by the text encoder. + If `images` and `inputs` are a `tf.data.Dataset`, outputs will be generated "batch-by-batch" and concatenated. Otherwise, all inputs will be processed as batches. diff --git a/keras_hub/src/models/stable_diffusion_3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py index 0a618a427..722bfdf27 100644 --- a/keras_hub/src/models/stable_diffusion_3/mmdit.py +++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py @@ -252,6 +252,17 @@ def call(self, inputs, height=None, width=None): position_embedding = ops.expand_dims(position_embedding, axis=0) return position_embedding + def get_config(self): + config = super().get_config() + del config["sequence_length"] + config.update( + { + "height": self.height, + "width": self.width, + } + ) + return config + def compute_output_shape(self, input_shape): return input_shape @@ -321,6 +332,7 @@ def get_config(self): config.update( { "embedding_dim": self.embedding_dim, + "frequency_dim": self.frequency_dim, "max_period": self.max_period, } ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py index 7a6714c65..3d551eb5e 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py @@ -14,8 +14,8 @@ class StableDiffusion3ImageToImage(ImageToImage): """An end-to-end Stable Diffusion 3 model for image-to-image generation. - This model has a `generate()` method, which generates image based on a pair - of image and prompt. + This model has a `generate()` method, which generates images based + on a combination of a reference image and a text prompt. Args: backbone: A `keras_hub.models.StableDiffusion3Backbone` instance. diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py index 1f9d4c19d..7374ea8e8 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py @@ -73,7 +73,7 @@ def setUp(self): "guidance_scale": ops.ones((2,)), } - def test_text_to_image_basics(self): + def test_image_to_image_basics(self): pytest.skip( reason="TODO: enable after preprocessor flow is figured out" )