Skip to content

Commit

Permalink
Fix minor bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Oct 5, 2024
1 parent a7cc7f2 commit da16e67
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
5 changes: 5 additions & 0 deletions keras_hub/src/models/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ def generate(
):
"""Generate image based on the provided `images` and `inputs`.
The `images` are reference images that will be resized to
`self.backbone.height` and `self.backbone.width`, then encoded into
latent space by the VAE encoder. The `inputs` are strings that will be
tokenized and encoded by the text encoder.
If `images` and `inputs` are a `tf.data.Dataset`, outputs will be
generated "batch-by-batch" and concatenated. Otherwise, all inputs will
be processed as batches.
Expand Down
12 changes: 12 additions & 0 deletions keras_hub/src/models/stable_diffusion_3/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ def call(self, inputs, height=None, width=None):
position_embedding = ops.expand_dims(position_embedding, axis=0)
return position_embedding

def get_config(self):
config = super().get_config()
del config["sequence_length"]
config.update(
{
"height": self.height,
"width": self.width,
}
)
return config

def compute_output_shape(self, input_shape):
return input_shape

Expand Down Expand Up @@ -321,6 +332,7 @@ def get_config(self):
config.update(
{
"embedding_dim": self.embedding_dim,
"frequency_dim": self.frequency_dim,
"max_period": self.max_period,
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
class StableDiffusion3ImageToImage(ImageToImage):
"""An end-to-end Stable Diffusion 3 model for image-to-image generation.
This model has a `generate()` method, which generates image based on a pair
of image and prompt.
This model has a `generate()` method, which generates images based
on a combination of a reference image and a text prompt.
Args:
backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def setUp(self):
"guidance_scale": ops.ones((2,)),
}

def test_text_to_image_basics(self):
def test_image_to_image_basics(self):
pytest.skip(
reason="TODO: enable after preprocessor flow is figured out"
)
Expand Down

0 comments on commit da16e67

Please sign in to comment.