
[Bounty] PyTorch & HuggingFace Interface #139

Open · risingsunomi wants to merge 548 commits into main

Conversation

risingsunomi

Hello all,

I’ve made some updates to the exo library based on the bounty mentioned in this tweet/X post. These changes aim to integrate PyTorch and expand access to various language models through Hugging Face’s AutoModelForCausalLM.

What's New?

  • ShardedHuggingFaceModel: Adds sharding support for Hugging Face models.
  • PyTorchDynamicShardInferenceEngine: A new inference engine that uses PyTorch tensors for dynamic sharding.

These updates enable the exo library to use PyTorch, allowing access to a broader range of language models.

Limitations and Bugs

Right now, the ShardedHuggingFaceModel is focused on LlamaForCausalLM from the Hugging Face transformers library. From that model, we break it up via the underlying LlamaModel and the layers it contains; we can then select a range of layers and run the PyTorch tensors over them as needed. I focused on Llama 3.1 8B, as that is the largest model my hardware can (barely) run. A rough sketch of the layer-selection idea follows.
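The sketch below uses example shard boundaries and the gated Llama 3.1 repo id; the real ShardedHuggingFaceModel wraps this logic, and boundary nodes additionally have to handle the embeddings, final norm, and lm_head.

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

START_LAYER, END_LAYER = 0, 15  # example shard boundaries for this node
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B"  # gated repo, needs huggingface-cli login

model = AutoModelForCausalLM.from_pretrained(
  MODEL_ID,
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

# LlamaForCausalLM exposes its decoder blocks at model.model.layers;
# keep only the blocks this node is responsible for.
model.model.layers = nn.ModuleList(model.model.layers[START_LAYER:END_LAYER + 1])
print(f"this node holds {len(model.model.layers)} decoder layers")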

Due to my current hardware limitations (specifically GPU and VRAM), I wasn't able to fully test this across multiple nodes. The model currently takes about 30 seconds per token for me (I have slow GPUs), which might be related to the absence of caching (not implemented due to VRAM constraints). Generation also runs without ever reaching an EOT token, and the outputs look random.

Request for Feedback

I’m sharing this in the hope that others can test it on more capable setups and provide feedback on how to enhance performance and stability.

Important Note on Meta LLaMA 3.1 Model

If you plan to test with the official Meta LLaMA 3.1 model, please note:

  • Access: You’ll need to request access and authenticate using huggingface-cli to download it.
  • Command: Run the following command before using the model:
    huggingface-cli login
    
    I'm exploring ways to simplify this process (see the sketch below), but for now, it's necessary.
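For scripted or CI setups, a non-interactive alternative is logging in from Python with a token taken from the environment (the HF_TOKEN variable name is just a convention used here, not something exo defines):

import os
from huggingface_hub import login

# Token created at https://huggingface.co/settings/tokens; avoids the interactive prompt.
login(token=os.environ["HF_TOKEN"])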

Chat API Update

  • Added an option to select the LLaMA 3.1 model in the chat API.

Looking forward to any feedback or suggestions you might have.

Thank you

@AlexCheema
Contributor

Hey, sorry for the delay. I haven't had a chance to check this properly yet. I'll be able to look next week.

@risingsunomi
Author

Sounds good. Let me know if anything is needed. Thank you


# self.past_key_values = DynamicCache()

def forward_layers(
Contributor

I really like this approach of generalising this so it works for other models without having to explicitly implement them.

Can you write a test for a model with a different architecture to make sure this generalises, e.g. RecurrentGemma?
I wonder if we need a little bit of model-specific behaviour to enable this in general?
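A rough sketch of the kind of cross-architecture check being asked for here; it only verifies that the decoder blocks can be located for a second architecture, not that sharded inference works, and the attribute paths plus the trick of shrinking num_hidden_layers are assumptions about transformers internals (both example repos are gated on the Hub):

import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

def find_decoder_layers(model: nn.Module) -> nn.ModuleList:
  # The ModuleList of decoder blocks lives at different paths per architecture.
  for path in ("model.layers", "transformer.h"):
    obj = model
    try:
      for attr in path.split("."):
        obj = getattr(obj, attr)
    except AttributeError:
      continue
    if isinstance(obj, nn.ModuleList):
      return obj
  raise ValueError(f"no decoder layer list found on {type(model).__name__}")

for model_id in ("meta-llama/Meta-Llama-3.1-8B", "google/recurrentgemma-2b"):
  config = AutoConfig.from_pretrained(model_id)
  config.num_hidden_layers = 2                      # keep the random-weight model tiny
  model = AutoModelForCausalLM.from_config(config)  # builds the architecture, no weight download
  layers = find_decoder_layers(model)
  assert len(layers) == 2, model_id
  print(model_id, type(layers[0]).__name__)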

async for chunk in response.content.iter_chunked(8192):
  f.write(chunk)

async def download_files(urls: List[str], output_paths: List[Path]):
Contributor

Is this used?

Author

No, I can remove it.

Contributor

Please remove

Author

Sorry, forgot to. Will do that now.

self.shard = None
self.model = None
self.tokenizer = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Contributor

Are these the only options? I think supporting e.g. Mac with mps would be great since then you can run heterogeneous clusters.

One thing to try at some point would be mixing MLX and PyTorch and see if they are interoperable with exactly the same model.

Author

With PyTorch, I don't think Mac support is fully rolled out yet. There seem to be some workarounds, but CUDA and CPU are the only options on the PyTorch download page; PyTorch has even stopped ROCm support for AMD.

They have a nightly build for testing MPS: https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/

Contributor

What about this in the official "stable" docs: https://pytorch.org/docs/stable/notes/mps.html

Author

Will try that, but I currently have no Mac to test on. Once I get through these other fixes, though, I can definitely add it for you or other Mac users to test.
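For reference, MPS detection is in stable PyTorch, so the device pick could be extended without waiting on a nightly; a minimal sketch of what could replace the cuda/cpu-only check:

import torch

def pick_device() -> torch.device:
  if torch.cuda.is_available():
    return torch.device("cuda")
  if torch.backends.mps.is_available():  # Apple Metal backend on macOS
    return torch.device("mps")
  return torch.device("cpu")

print(pick_device())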

# Load the model
self.full_model = AutoModelForCausalLM.from_pretrained(
  shard.model_id,
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
Contributor

Again, are these the only options? We'd want support across other platforms.
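A sketch of letting the dtype follow the chosen device rather than CUDA availability alone (half precision on mps is an assumption about what the target models tolerate, not something tested in this PR):

import torch

def pick_dtype(device: torch.device) -> torch.dtype:
  # Half precision on GPU-like backends, full precision on CPU.
  return torch.float16 if device.type in ("cuda", "mps") else torch.float32

print(pick_dtype(torch.device("cpu")))  # torch.float32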


layers.append(layer)

self.full_model.model.layers = nn.ModuleList(layers)
Contributor

What does the peak memory usage look like here? I'm not sure of the specifics of Python, but is this going to hold each layer twice? Perhaps setting them in place would be more memory efficient.

Author

They shouldn't be held twice: when ensure_shard is called from infer_prompt or infer_tensor, the class's init function runs and loads only the layers needed for that shard. I will double-check the memory limits and usage, though.
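One way to probe the reviewer's concern is to trim the layer list in place and drop the excluded blocks before moving anything to the GPU, then look at peak memory. A sketch, assuming the shard exposes start_layer/end_layer the way the rest of this PR does:

import gc
import torch
import torch.nn as nn

def trim_to_shard(full_model, start_layer: int, end_layer: int) -> None:
  # Keep only the decoder blocks in [start_layer, end_layer]; the rest become unreferenced.
  kept = full_model.model.layers[start_layer:end_layer + 1]
  full_model.model.layers = nn.ModuleList(kept)
  gc.collect()                  # free the dropped blocks' CPU memory
  if torch.cuda.is_available():
    torch.cuda.empty_cache()    # release cached GPU allocations
    print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")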



# Load the model
self.full_model = AutoModelForCausalLM.from_pretrained(
Contributor

Will this download the entire model?
We have code to selectively download the model from HuggingFace so you don't have to download all layers on every device: exo/download

Contributor

Also, this won't work with our download-progress code: we show the model's download progress in the TUI.

Author

Will look at using that code, because yes, it currently does download the whole model.
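If wiring into exo/download proves awkward, huggingface_hub can also fetch only the weight files a shard needs by reading the safetensors index; a sketch (example repo id and layer range, and the index filename assumes a sharded safetensors checkpoint):

import json
from huggingface_hub import hf_hub_download, snapshot_download

REPO_ID = "meta-llama/Meta-Llama-3.1-8B"  # example; gated repo
START_LAYER, END_LAYER = 0, 15            # example shard boundaries

# The index maps every weight name to the .safetensors file that contains it.
index_path = hf_hub_download(REPO_ID, "model.safetensors.index.json")
with open(index_path) as f:
  weight_map = json.load(f)["weight_map"]

needed_files = set()
for name, filename in weight_map.items():
  if name.startswith("model.layers."):
    layer_idx = int(name.split(".")[2])
    if START_LAYER <= layer_idx <= END_LAYER:
      needed_files.add(filename)
  else:
    needed_files.add(filename)  # embeddings, final norm, lm_head live outside the per-layer namespace

local_dir = snapshot_download(
  REPO_ID,
  allow_patterns=["*.json", "tokenizer*"] + sorted(needed_files),
)
print(local_dir, len(needed_files), "weight files downloaded")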

import torch
from torch.nn import functional as F

def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alpha_p=0.0):
Contributor

Can this be imported from somewhere rather than copy-pasted into the codebase? It looks like boilerplate code from somewhere.

Author

Yes, I was testing it with the default values but will clean that part up. I will set it in the Interface class settings so it can be configured.
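For reference, a minimal self-contained temperature / top-k / top-p sampler along the lines of what the pasted helper appears to do (the alpha_f/alpha_p repetition-penalty terms are left out):

import torch
from torch.nn import functional as F

def sample_logits(logits: torch.Tensor, temperature: float = 1.0,
                  top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
  # Sample one token id from a [vocab_size] (or [batch, vocab_size]) logits tensor.
  if temperature == 0.0:
    return torch.argmax(logits, dim=-1, keepdim=True)  # greedy decoding
  logits = logits / temperature
  if top_k > 0:
    kth = torch.topk(logits, top_k).values[..., -1, None]
    logits = logits.masked_fill(logits < kth, float("-inf"))
  if top_p < 1.0:
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    probs = F.softmax(sorted_logits, dim=-1)
    cum = torch.cumsum(probs, dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds top_p.
    sorted_logits = sorted_logits.masked_fill(cum - probs > top_p, float("-inf"))
    logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
  return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)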

& .\.venv\Scripts\Activate.ps1

# Install the package in the virtual environment
pip install .
Contributor

Windows? Did this work on Windows? Curious if it works there.

Author

I was testing on Windows but couldn't fully get it working. I will test again and confirm, as I switched to Linux for further development.

Contributor

@AlexCheema I can confirm that it is working on Windows, but there are a few issues:

  1. PyTorch 2.4 doesn't install; to get it working, it needs the nightly build.

  2. In main.py,

     for s in [signal.SIGINT, signal.SIGTERM]:
       loop.add_signal_handler(s, handle_exit)

     isn't supported on Windows.

If I change handle_exit() to:

    def handle_exit():
      asyncio.ensure_future(shutdown(loop))

    if platform.system() != "Windows":
      for s in [signal.SIGINT, signal.SIGTERM]:
        loop.add_signal_handler(s, handle_exit)
    else:
      # On Windows, we can only reliably catch SIGINT (Ctrl+C)
      signal.signal(signal.SIGINT, lambda signum, frame: handle_exit())

  3. Getting some kind of network error between the GUI and the backend.

It seems to work. It's insanely slow for me though (no GPU... the Raspberry Pis are much faster 😄). Windows changes are perhaps out of scope for this PR though.

[screenshot]

Contributor

Wowzah!

This is exciting. A lot of the community were mad we didn't support Windows. We have a bounty here if you want to get it working there once this one is merged: #186

@AlexCheema
Contributor

Great work. You clearly thought about this and implemented a really nice solution. I particularly like the generalisation of model splitting, rather than doing each one separately.

Take a look through the comments I left.

@AlexCheema
Contributor

The main thing I want to address and test is device support. We can make this the default inference engine if it works reliably across many devices.

On that point, if we can automate the bootstrapping of the environment for each user (e.g. install drivers, whatever else is needed to run on their device) that would be great. We don't have to do this in this PR/bounty, we can do another. But I would love to discuss and figure out how this can best be done.

@@ -0,0 +1,21 @@
import unittest
Contributor

Run this test in CircleCI: ./.circleci/config.yml

@@ -0,0 +1,33 @@

Contributor

Run this test in CircleCI: ./.circleci/config.yml

n_layers=12
)

engine = PyTorchDynamicShardInferenceEngine(
Contributor

Is this test complete? We need a test that tests the model splitting. Take a look at exo/inference/test_inference_engine.py. You can just add the test there.
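Not the exo test itself, but a rough sketch of the equivalence a splitting test can assert, using GPT-2 because its blocks can be driven with bare hidden states (the split point and model id are arbitrary):

import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2").eval()

input_ids = tok("exo shards models across devices", return_tensors="pt")["input_ids"]

with torch.no_grad():
  reference = model(input_ids).last_hidden_state

  # Replay the same computation as two "shards" of decoder blocks.
  position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
  hidden = model.wte(input_ids) + model.wpe(position_ids)
  split = len(model.h) // 2
  for block in model.h[:split]:   # shard 1
    hidden = block(hidden)[0]
  for block in model.h[split:]:   # shard 2
    hidden = block(hidden)[0]
  hidden = model.ln_f(hidden)

assert torch.allclose(hidden, reference, atol=1e-4)
print("two-shard forward matches the full forward")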

Author

@risingsunomi left a comment

Reviewed and looks good. Will work through notes to improve

…ention but not implemented fully, added RMSNorm from modeling llama on HF, added weight renaming and loading along with handling no lm_head weight in safetensor where you then use embed weight as seen with gpt2, still not generating proper responses, further dev being done
@risingsunomi
Author

I'm finding that building a pure PyTorch implementation isn't working, even following all the examples. I'm going to try using the official Meta code and hacking it for the sharding we need. I'll keep trying my method, but I'm not making much progress: I'm able to shard the safetensors and everything, but inference is not working at all. Still hitting at it.

Any other eyes on this would be appreciated. Right now it's in shambles, but I am using the torchtune method as opposed to fairscale. I think I might switch to fairscale though, as the official Meta Llama model is looking better.

My WIP code

Sorry again for the delay on this; my regular job has me swamped, but I'm going to try to push this faster before the month is out.

Thank you again

@risingsunomi
Author


Spoke too soon: I think I can get this working.
[screenshot]

…ting and then inference engine testing but we are almost there. HELL YEAAAAAAAAAAA
…it and full test, separating huggingface and torchtune inference engines