
[Feature Request] Example of MLLM using MLX #207

Closed
fozziethebeat opened this issue Dec 31, 2023 · 12 comments
Labels
enhancement New feature or request

Comments

@fozziethebeat

With ml-ferret out, it would be great to include an MLLM example in this repo, namely ml-ferret or just LLaVA itself. Since it is LLaMA based, I think this would primarily require implementing the image encoding step.

@awni awni added the enhancement New feature or request label Dec 31, 2023
@awni
Member

awni commented Dec 31, 2023

Agreed, that would be super cool. If anyone is interested in contributing it, let us know!

@gboduljak
Contributor

> Agreed, that would be super cool. If anyone is interested in contributing it, let us know!

I would like to contribute this example soon (perhaps this weekend). Some time ago, I proposed implementing ViT (CLIP) in #143. It would therefore be useful to have an implementation of ViT (CLIP) that exactly matches the one used in ferret, since that would let us implement ferret quickly. According to the ferret implementation, the image encoder is just a wrapper around CLIP.

However, implementing ferret is not only a matter of implementing image encoding, because ferret also relies on bilinear interpolation via torch.nn.functional.grid_sample, which MLX does not currently provide. That said, I recently implemented bilinear interpolation for Upsample2d (ml-explore/mlx#414).
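For context, here is a tiny PyTorch illustration of the op in question (the shapes below are made-up example values, not the ones ferret actually uses):

import torch
import torch.nn.functional as F

# grid_sample bilinearly samples a feature map at arbitrary locations given in
# normalized [-1, 1] coordinates. MLX has no equivalent of this op yet.
features = torch.randn(1, 256, 24, 24)   # (N, C, H_in, W_in) feature map
grid = torch.rand(1, 32, 32, 2) * 2 - 1  # (N, H_out, W_out, 2) sample points
sampled = F.grid_sample(features, grid, mode="bilinear", align_corners=False)
print(sampled.shape)                      # torch.Size([1, 256, 32, 32])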

@fozziethebeat
Author

LLaVA might be easier to start with, as v1.5 is literally just a fine-tuned LLaMA model + CLIP + a small projector layer.
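For a sense of scale, here is a rough sketch of what that projector could look like in MLX; the class and argument names are just placeholders, not taken from any existing implementation:

import mlx.core as mx
import mlx.nn as nn


class MultiModalProjector(nn.Module):
    # Maps CLIP image features into the LLaMA embedding space.
    # LLaVA v1.5 uses a small two-layer MLP with a GELU in between.
    def __init__(self, vision_hidden_size: int, text_hidden_size: int):
        super().__init__()
        self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size)
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size)

    def __call__(self, image_features: mx.array) -> mx.array:
        return self.linear_2(nn.gelu(self.linear_1(image_features)))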

@gboduljak
Contributor

> LLaVA might be easier to start with, as v1.5 is literally just a fine-tuned LLaMA model + CLIP + a small projector layer.

I think this is a great idea. I will probably start with ViT-CLIP and then implement LLaVA.

@fozziethebeat
Author

That would be awesome! At a bare minimum, an image CLIP encoder will be really helpful. I saw that the Stable Diffusion example has a text CLIP encoder, so the image half should only need a few small changes.

I was semi-looking at this just to see how hard it is. As a reference, the original LLaVA implementation is likely easier to work with; the transformers v4.35 rewrite made quite a few changes.

@nkasmanoff
Contributor

nkasmanoff commented Jan 14, 2024

Hey, posting here since I'm interested in multi-modal models. I'm currently trying to convert bakLlava. The specific model doesn't matter too much to me, but this one worked out of the box, so I went with it.

I'm trying to understand the best strategy for converting a model to MLX. Is there a guide somewhere?

Right now I'm looking at the model's implementation in HF https://github.com/huggingface/transformers/blob/bc72b4e2cdcbc80d5f56731f35dbc9c18b4c8de6/src/transformers/models/llava/modeling_llava.py#L237 and trying to build versions of these modules in MLX. I've just started, but the idea is to make modules that look like the following, right?


import mlx.nn as nn


class VisionTower(nn.Module):
    # TODO
    pass


class MultiModalProjector(nn.Module):
    # TODO
    pass


class LanguageModel(nn.Module):
    # TODO
    pass


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.vision_tower = VisionTower(args)
        self.multi_modal_projector = MultiModalProjector(args)
        self.language_model = LanguageModel(args)
        self.pad_token_id = args.pad_token_id
        self.vocab_size = args.vocab_size

    def __call__(self):
        # TODO
        pass


(This is all I've done so far)

Curious if you all had a different strategy in mind? It seems like starting with CLIP is the best approach, so I'm also going to keep an eye on how that's going. Let me know if you need any help.

@gboduljak
Contributor

> I'm trying to understand the best strategy for converting a model to MLX. Is there a guide somewhere?

Hi @nkasmanoff :) I am also very interested in porting multimodal models. My end goal is to implement ferret in MLX, and I will probably need some help. Here is my (draft) pull request for CLIP (#315). I think it is going well and I am confident it will be done soon. Once we have CLIP, it is relatively simple to port LLaVA to MLX.

In terms of the general approach, I think a good starting point is to design the model structure (e.g. your Python class hierarchy) so that it closely follows the transformers implementation and the HuggingFace weights. The next step is to write a script that converts the HuggingFace weights into your weight layout. This resource is useful: #155. Here is how I did it in #315. I based my work on the Stable Diffusion example (https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
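To make that concrete, here is a rough sketch of what such a conversion script can look like. The key renaming in map_key is only a placeholder; the real mapping has to be worked out against the transformers module names and your MLX class attributes.

# Rough sketch of a HuggingFace -> MLX weight conversion script. The key
# renaming below is a placeholder; adapt it to your MLX module tree.
import numpy as np
import torch
from transformers import AutoModel


def map_key(key: str) -> str:
    # Example rename only: drop the "model." prefix used by transformers.
    return key.removeprefix("model.")


def convert(hf_repo: str, out_path: str) -> None:
    model = AutoModel.from_pretrained(hf_repo)
    weights = {
        # Cast dtypes here if needed (e.g. bf16 -> fp32) before calling numpy().
        map_key(k): v.detach().numpy()
        for k, v in model.state_dict().items()
    }
    # The resulting .npz can be loaded in MLX with mx.load / nn.Module.load_weights.
    np.savez(out_path, **weights)


convert("openai/clip-vit-base-patch32", "weights.npz")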

@nkasmanoff
Contributor

@gboduljak Thank you, this is a lot of great info! I will try to catch myself up and help :-)

@gboduljak
Contributor

gboduljak commented Jan 14, 2024

> @gboduljak Thank you, this is a lot of great info! I will try to catch myself up and help :-)

Thank you :)

Here are some concrete things where I need some help:

  1. Verifying that the initialization of CLIP (CLIP (ViT) #315) is correct (i.e. that it matches the method implemented in https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/clip/)
  2. Completing the implementation of CLIPTokenizer (the implementation in the PR should work for most cases, but it is not complete)
  3. Implementing CLIP image input pre-processing (we currently rely on the transformers CLIPProcessor; see the sketch at the end of this comment)
  4. Removing the dependency on transformers
  5. Proofreading README.md

If you have time, I would appreciate your general review of #315. The sooner we have a satisfactory CLIP implementation, the sooner we can move on to LLaVA and/or ferret.
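For (3), a rough sketch of CLIP-style image pre-processing without CLIPProcessor could look like the following. The normalization constants are the standard CLIP values; the exact resize/crop policy still needs to be checked against the CLIPImageProcessor config we want to match.

import numpy as np
from PIL import Image

# Standard CLIP normalization constants.
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def preprocess(image: Image.Image, size: int = 224) -> np.ndarray:
    # Resize the shorter side to `size`, then take a center crop.
    w, h = image.size
    scale = size / min(w, h)
    image = image.resize((round(w * scale), round(h * scale)), Image.BICUBIC)
    w, h = image.size
    left, top = (w - size) // 2, (h - size) // 2
    image = image.crop((left, top, left + size, top + size))
    # Scale to [0, 1] and normalize per channel. Keep channels last, which is
    # the layout MLX convolutions expect.
    x = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    return (x - CLIP_MEAN) / CLIP_STD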

@nkasmanoff
Contributor

@gboduljak No problem, but after looking it over, I'm not sure I can be extraordinarily helpful beyond some simpler tasks. This is much lower-level coding than I'm used to :-). That said, I can help with 3, 5, and a general review. Don't the tests you already have verify that 1. is complete?

I'm also considering making a demo train.py file; have you considered something like this? If I come up with something that looks like it works using the code you have here, I'll submit a PR to the branch you already made.

@gboduljak
Contributor

gboduljak commented Jan 15, 2024

> @gboduljak No problem, but after looking it over, I'm not sure I can be extraordinarily helpful beyond some simpler tasks. This is much lower-level coding than I'm used to :-). That said, I can help with 3, 5, and a general review. Don't the tests you already have verify that 1. is complete?
>
> I'm also considering making a demo train.py file; have you considered something like this? If I come up with something that looks like it works using the code you have here, I'll submit a PR to the branch you already made.

Thank you for taking a look at my code. Unfortunately, the tests do not verify that our model initialization is correct; they verify that our model inference is correct. In other words, we test whether our implementation, loaded with pretrained weights, produces the same output as the transformers implementation initialized with the same pretrained weights.
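Roughly, such a parity test looks like the sketch below (this is only an illustration, not the actual test code in #315; the MLX output and pre-processed pixels are passed in by the caller):

import numpy as np
import torch
from transformers import CLIPVisionModel


def check_vision_parity(mlx_output: np.ndarray, pixels: np.ndarray,
                        hf_repo: str = "openai/clip-vit-base-patch32") -> bool:
    # `pixels` is a (H, W, C) float32 image already pre-processed for CLIP;
    # `mlx_output` is the MLX vision encoder's hidden states for that image.
    reference = CLIPVisionModel.from_pretrained(hf_repo)
    reference.eval()
    with torch.no_grad():
        torch_pixels = torch.from_numpy(pixels).permute(2, 0, 1)[None]
        expected = reference(torch_pixels).last_hidden_state.numpy()
    return np.allclose(mlx_output, expected, atol=1e-4)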

Porting CLIPProcessor would be very useful, because it would let us eliminate the dependency on transformers.
train.py would also be very interesting, although we likely cannot run the full training :).

Also, feel free to ask me to explain the code. Perhaps we should move further discussion of CLIP to its issue, #143.

@awni
Member

awni commented Mar 2, 2024

Closing as we added #461.

@awni awni closed this as completed Mar 2, 2024