SDXL with Compel: unexpected keyword argument output_hidden_states #357

Closed
neo opened this issue Nov 27, 2023 · 22 comments · Fixed by #581
neo (Contributor) commented Nov 27, 2023

From looking at the code, I can tell that output_hidden_states is always true for SDXL's text encoders; however, compel always passes it explicitly (even though it's also True for SDXL): https://github.com/damian0815/compel/blob/v2.0.2/src/compel/embeddings_provider.py#L392. This causes NeuronModelTextEncoder to error out with unexpected keyword argument 'output_hidden_states'.
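
For reference, the call site in compel looks roughly like this (paraphrased from the lines linked above, compel v2.0.2):

    text_encoder_output = self.text_encoder(token_ids,
                                            attention_mask,
                                            output_hidden_states=True,
                                            return_dict=True)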

So I'm wondering, what's the best way to approach this, as well as a possibly similar issue with return_dict? Thanks!

@JingyaHuang (Collaborator)

Hi @neo,

Whether the text encoder outputs hidden_states is decided during compilation and cannot be adjusted at inference time. Does your checkpoint's config explicitly set output_hidden_states to True? If not (the default in transformers), the traced text encoder won't output hidden_states.

The traced model can only take positional arguments and outputs a tuple, so in optimum we put a wrapper around it to control the keyword arguments and ensure they are passed to the model in the same order as when it was traced.

class NeuronModelTextEncoder(_NeuronDiffusionModelPart):
    def __init__(
        self,
        model: torch.jit._script.ScriptModule,
        parent_model: NeuronBaseModel,
        config: Optional[DiffusersPretrainedConfig] = None,
        neuron_config: Optional[Dict[str, str]] = None,
    ):
        super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_TEXT_ENCODER_NAME)

    def forward(self, input_ids: torch.Tensor):
        inputs = (input_ids,)
        outputs = self.model(*inputs)
        return outputs

So if you need output_hidden_states=True, the first thing to do is to ensure it was enabled while tracing. Then, during inference with compel, you could wrap the text encoder with dummy keyword arguments for output_hidden_states and return_dict (passing the output tuple to BaseModelOutputWithPooling).
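
A minimal sketch of such a wrapper (the class name is made up, and the output tuple layout is an assumption — the actual order depends on how your encoder was traced):

    from transformers.modeling_outputs import BaseModelOutputWithPooling

    class CompelTextEncoderShim:
        # Hypothetical shim: accepts compel's keyword arguments and repackages
        # the traced model's flat output tuple into a ModelOutput.
        def __init__(self, neuron_text_encoder):
            self.neuron_text_encoder = neuron_text_encoder
            self.config = neuron_text_encoder.config

        def __call__(self, input_ids, attention_mask=None, output_hidden_states=True, return_dict=True):
            # The compiled graph is fixed: validate instead of toggling behavior.
            assert output_hidden_states, "this encoder must be traced with output_hidden_states=True"
            # attention_mask is ignored here, since the traced graph takes no mask input.
            outputs = self.neuron_text_encoder(input_ids)
            if not return_dict:
                return outputs
            # Assumed layout: (last_hidden_state, pooler_output, *hidden_states)
            return BaseModelOutputWithPooling(
                last_hidden_state=outputs[0],
                pooler_output=outputs[1],
                hidden_states=outputs[2:],
            )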

P.S. I did not expect the components to be used outside of diffusers, but it makes sense to me to support these extra arguments natively in optimum-neuron. Would you like to file a PR to make it native in optimum?

@JingyaHuang JingyaHuang self-assigned this Nov 28, 2023
neo (Contributor, Author) commented Nov 28, 2023

Hi @JingyaHuang, thank you for the response!

I had the same understanding that these extra options won't really affect the post-compilation runtime, so I guess my questions were:

  • to confirm there's a way to set these options as needed; I saw where output_hidden_states is set (below), but what about return_dict?

    if is_sdxl:
        pipeline.text_encoder.config.output_hidden_states = True
        models_for_export.append((DIFFUSION_MODEL_TEXT_ENCODER_NAME, copy.deepcopy(pipeline.text_encoder)))
        text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
        if text_encoder_2 is not None:
            text_encoder_2.config.output_hidden_states = True

  • and secondly, I was thinking in the same direction: a dummy shim wrapping the text encoder to work around this (potentially with an assertion that the arguments match the compile-time config); could you point me to where you think the best place to do that would be?

And I'd be more than happy to send in a PR once I get something working and have a generalized solution.

neo (Contributor, Author) commented Nov 28, 2023

I was able to set up a temporary shim – however, other than output_hidden_states and return_dict (which I still don't know whether it's True in a default compilation), there's also attention_mask, which isn't compiled in: https://github.com/damian0815/compel/blob/v2.0.2/src/compel/embeddings_provider.py#L390-L393

I'm wondering, is there a quick way to do a compilation matching this signature and number of arguments for the text encoders?

Thanks!

@JingyaHuang (Collaborator)

Hi @neo,

  • The TorchScript module cannot return a dict as output, but with a wrapper like NeuronModelTextEncoder we can reformulate the output tuple by passing it to a subclass of ModelOutput and return a dictionary; through this we can enable return_dict.

  • I am working on a PR supporting encoder-decoder models ([Inference] Add t5 support for export and inference #267). To set up an example, I added optional_outputs support for decoders; it might give you some ideas on supporting it for stable diffusion. Basically, we need to configure what we want as output with NeuronConfig. That PR is still in progress; let me get back to you once we have a clear scheme.

neo (Contributor, Author) commented Nov 29, 2023

Hi @JingyaHuang! How about attention_mask? Is it possible to also include this argument in the text encoder's forward function? I'm more uncertain about this one because it's an actual input rather than something that just controls the output.

JingyaHuang (Collaborator) commented Nov 29, 2023

> I was able to set up a temporary shim – however, other than output_hidden_states and return_dict (which I still don't know whether it's True in a default compilation), there's also attention_mask, which isn't compiled in: https://github.com/damian0815/compel/blob/v2.0.2/src/compel/embeddings_provider.py#L390-L393
>
> I'm wondering, is there a quick way to do a compilation matching this signature and number of arguments for the text encoders?
>
> Thanks!

Yes, attention_mask can be added as an optional input.

In general, optimum does not support every possible input by default, as it sometimes comes with extra compute cost. But we try to support different use cases by letting users customize whatever is possible.
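
For instance, a rough sketch of what tracing with attention_mask as a second positional input could look like (the wrapper and shapes are just for illustration, not the actual optimum-neuron export code; it assumes pipeline.text_encoder is the CLIP text encoder from a diffusers pipeline):

    import torch
    import torch_neuronx

    class TextEncoderWithMask(torch.nn.Module):
        # Fixes the positional signature that the traced graph will expect.
        def __init__(self, text_encoder):
            super().__init__()
            self.text_encoder = text_encoder

        def forward(self, input_ids, attention_mask):
            return self.text_encoder(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=False,  # traced modules return tuples, not dicts
            )

    # Static example shapes for compilation (e.g. batch=1, sequence length=77)
    example_inputs = (
        torch.zeros((1, 77), dtype=torch.long),
        torch.ones((1, 77), dtype=torch.long),
    )
    traced = torch_neuronx.trace(TextEncoderWithMask(pipeline.text_encoder), example_inputs)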

neo (Contributor, Author) commented Nov 29, 2023

That makes total sense; I understand it might not be a super common use case for the library to bake this in.

While acknowledging this is not your job or obligation, would you be generous enough to point me to how I could do a compilation that accepts the attention_mask input? (assuming this needs to happen while tracing) 😳

neo (Contributor, Author) commented Dec 7, 2023

update: based on your suggestions, I was able to figure out output_hidden_states and return_dict with a dummy shim class 😆

now there's only attention_mask that I'm still trying to figure out... I did notice that in the PR you mentioned (https://github.com/huggingface/optimum-neuron/pull/267/files#diff-9a335bc75a3caafe37679f808893160ac22d346773c0b0aa0710f650ef7f6f89R85-R86), the TextSeq2SeqNeuronConfig does allow a secondary attention_mask encoder input argument.

I'll try to see what I can do about that; let me know if you have any suggestions/recommendations 😊

neo (Contributor, Author) commented Dec 8, 2023

I inspected it, and for our use case the attention_mask passed from compel is actually usually just all ones; so I'm guessing we could ignore it and not pass it into the text encoder call without affecting the output? I'm not sure how common/universal that is, to justify it as a suitable patch for the library.
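
Roughly, the guard I have in mind before dropping the mask (a hypothetical helper, just to fail loudly if a non-trivial mask ever shows up):

    import torch

    def check_droppable_mask(attention_mask):
        # Only drop the mask if it carries no information, i.e. it is None
        # or all ones (which is what compel produces for Stable Diffusion).
        if attention_mask is not None and not torch.all(attention_mask != 0):
            raise ValueError(
                "The compiled text encoder takes no attention_mask, "
                "but a non-trivial mask was supplied."
            )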

I also noticed that the neuron config of the text encoder specifies output names, which was very helpful in confirming which output tuple element is which when assembling the return_dict. However, the neuron config of text_encoder_2 has the incorrect class and thus incorrect outputs, even though the config JSON has the correct output_names; I opened a new issue for that: #371

neo (Contributor, Author) commented Dec 8, 2023

> I'm not sure how common/universal that is, to justify it as a suitable patch for the library.

It seems like for SD the attention mask would always be all ones? https://github.com/damian0815/compel/blob/v2.0.2/src/compel/embeddings_provider.py#L47

> padding_attention_mask_value: Value to write into the attention mask for padding tokens. Stable Diffusion needs 1.

@Suprhimp

Hi, can I ask what to do if I want to use SDXL with compel?

I'm stuck on this error:

  File "/home/ubuntu/.local/lib/python3.8/site-packages/compel/embeddings_provider.py", line 390, in _encode_token_ids_to_embeddings
    text_encoder_output = self.text_encoder(token_ids,
  File "/home/ubuntu/.local/lib/python3.8/site-packages/optimum/neuron/modeling_diffusion.py", line 718, in __call__
    return self.forward(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'output_hidden_states'

What should I do when I compile the model? When I compiled with optimum-cli with the output_hidden_states flag set to True, it said it's not supported yet.

I even got a segmentation fault when I tried with the optimum.neuron Python module..

Can you give me some hints on this? ;)

@HuggingFaceDocBuilderDev

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Thank you!

neo (Contributor, Author) commented Apr 13, 2024

not stale

@Suprhimp

Yep, I think so too. I want to run Stable Diffusion models with compel but I can't, and I think compel is quite important for that.

@JingyaHuang (Collaborator)

If compel is a widely used option, I'm happy to support it.

@JingyaHuang (Collaborator)

Hi @neo and @Suprhimp, thanks for being active in the discussion. Do you have a reproducible script using Optimum Neuron with Compel that I can play with? That would give me a better idea of how to make Optimum Neuron compatible with it. THX :>

neo (Contributor, Author) commented Apr 23, 2024

I have a semi-ready fix, @JingyaHuang 😆 let me verify #371 tomorrow and I'll try to open a PR

Suprhimp commented Apr 24, 2024

Does this look good, @JingyaHuang? This is how I usually use compel with regular Stable Diffusion:

from optimum.neuron import NeuronStableDiffusionXLPipeline
from compel import Compel, ReturnedEmbeddingsType

pipe = NeuronStableDiffusionXLPipeline.from_pretrained("sdxl_turbo_neuron/", data_parallel_mode="all")
prompt = "Self-portrait oil painting++++++++++, a beautiful cyborg with golden hair, 8k"
negative_prompt = "worst quality"

compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
prompt_embeds, pooled = compel(prompt)
neg_prompt_embeds, neg_pooled = compel(negative_prompt)

images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=neg_prompt_embeds,
    pooled_prompt_embeds=pooled,
    negative_pooled_prompt_embeds=neg_pooled,
    guidance_scale=0.0,
    num_inference_steps=1,
).images

@JingyaHuang (Collaborator)

That will be awesome, thx @neo! Let me know if you need any support.

And @Suprhimp, thanks for the example!

@JingyaHuang (Collaborator)

Hey folks, I just prepared a PoC supporting the feature in a dev branch here: https://github.com/huggingface/optimum-neuron/tree/lez-support-compel.

It's still a rough draft, but it works with the example @Suprhimp sent and might help @neo with the PR you'd like to submit. Please tell me if it meets what's needed for supporting compel. ty!

neo (Contributor, Author) commented Apr 24, 2024

I think it depends on the scope of changes you'd like – in the solution I had, the exporting side didn't need to change, because it defaults to "output_hidden_states": true, which is what compel needs. So in my case I just needed to update the forward method of NeuronModelTextEncoder to handle return_dict and assert that output_hidden_states is consistent with what was compiled.

Edit: so it's up to you whether we also want to include the option at compilation time, to potentially open the door to other future integrations.
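
Concretely, the forward change I have in mind looks roughly like this (a sketch only — it assumes the traced encoder emits (last_hidden_state, pooler_output, *hidden_states)):

    def forward(self, input_ids: torch.Tensor, attention_mask=None,
                output_hidden_states=True, return_dict=True):
        # attention_mask is accepted but unused: compel's mask is all ones for SD.
        # The traced graph is fixed, so keyword arguments can only be validated.
        assert output_hidden_states == getattr(self.config, "output_hidden_states", True)
        outputs = self.model(input_ids)
        if not return_dict:
            return outputs
        return BaseModelOutputWithPooling(
            last_hidden_state=outputs[0],
            pooler_output=outputs[1],
            hidden_states=outputs[2:],
        )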

@JingyaHuang (Collaborator)

Thanks for the insight, @neo!

@neo @Suprhimp Do you mind reviewing the PR here? If it goes well, let's make the SD pipelines compel-compatible in the next Optimum Neuron release.
