diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 8e951db68cc8..04a150d52c0a 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -205,6 +205,7 @@ def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
+        max_sequence_length: int = 256,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -224,7 +225,7 @@ def _get_t5_prompt_embeds(
         text_inputs = self.tokenizer_3(
             prompt,
             padding="max_length",
-            max_length=self.tokenizer_max_length,
+            max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
             return_tensors="pt",
@@ -235,8 +236,8 @@ def _get_t5_prompt_embeds(
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
             logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer_max_length} tokens: {removed_text}"
+                "The following part of your input was truncated because `max_sequence_length` is set to"
+                f" {max_sequence_length} tokens: {removed_text}"
             )

         prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
@@ -323,6 +324,7 @@ def encode_prompt(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         clip_skip: Optional[int] = None,
+        max_sequence_length: int = 256,
     ):
         r"""

@@ -403,6 +405,7 @@ def encode_prompt(
             t5_prompt_embed = self._get_t5_prompt_embeds(
                 prompt=prompt_3,
                 num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
                 device=device,
             )

@@ -456,7 +459,10 @@ def encode_prompt(
             negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)

             t5_negative_prompt_embed = self._get_t5_prompt_embeds(
-                prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
+                prompt=negative_prompt_3,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
             )

             negative_clip_prompt_embeds = torch.nn.functional.pad(
@@ -486,6 +492,7 @@ def check_inputs(
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -557,6 +564,9 @@ def check_inputs(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        if max_sequence_length is not None and max_sequence_length > 512:
+            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
     def prepare_latents(
         self,
         batch_size,
@@ -643,6 +653,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 256,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -726,6 +737,7 @@ def __call__(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`.

         Examples:
@@ -753,6 +765,7 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
         )

         self._guidance_scale = guidance_scale
@@ -790,6 +803,7 @@ def __call__(
             device=device,
             clip_skip=self.clip_skip,
             num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
         )

         if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index e7b5ebc92a55..1ad8a8f3c42b 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -220,6 +220,7 @@ def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
+        max_sequence_length: int = 256,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -239,7 +240,7 @@ def _get_t5_prompt_embeds(
         text_inputs = self.tokenizer_3(
             prompt,
             padding="max_length",
-            max_length=self.tokenizer_max_length,
+            max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
             return_tensors="pt",
@@ -250,8 +251,8 @@ def _get_t5_prompt_embeds(
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
             logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer_max_length} tokens: {removed_text}"
+                "The following part of your input was truncated because `max_sequence_length` is set to"
+                f" {max_sequence_length} tokens: {removed_text}"
             )

         prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
@@ -340,6 +341,7 @@ def encode_prompt(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         clip_skip: Optional[int] = None,
+        max_sequence_length: int = 256,
     ):
         r"""

@@ -420,6 +422,7 @@ def encode_prompt(
             t5_prompt_embed = self._get_t5_prompt_embeds(
                 prompt=prompt_3,
                 num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
                 device=device,
             )

@@ -473,7 +476,10 @@ def encode_prompt(
             negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)

             t5_negative_prompt_embed = self._get_t5_prompt_embeds(
-                prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
+                prompt=negative_prompt_3,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
             )

             negative_clip_prompt_embeds = torch.nn.functional.pad(
@@ -502,6 +508,7 @@ def check_inputs(
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -573,6 +580,9 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) @@ -686,6 +696,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. @@ -765,6 +776,7 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. Examples: @@ -788,6 +800,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale @@ -824,6 +837,7 @@ def __call__( device=device, clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, ) if self.do_classifier_free_guidance: