[Enhancement] Add external LLM service support #25

Merged 3 commits on Jun 10, 2024
38 changes: 37 additions & 1 deletion README.md
@@ -24,7 +24,7 @@ LLM Chat allows user interact with LLM to obtain a JSON-like structure. There ar
- `Omost Load Canvas Conditioning`: Load the JSON layout prompt previously saved

Optionally you can use the show-anything node to display the json text and save it for later.
The LLM runs slow. Each chat takes about 3~5min on 4090.
The official LLM method runs slowly; each chat takes about 3~5 min on a 4090. (You can now use TGI to deploy accelerated inference. For details, see [**Accelerating LLM**](#accelerating-llm).)

Examples:
- Simple LLM Chat: ![image](https://github.com/huchenlei/ComfyUI_omost/assets/20929282/896eb810-6137-4682-8236-67cfefdbae99)
@@ -227,3 +227,39 @@ You can use the built-in region editor on `Omost Load Canvas Conditioning` node
### Compose with other control methods
You can freely compose the region condition with other control methods like ControlNet/IPAdapter. The following workflow applies an IPAdapter model to the character region by selecting the corresponding mask.
![image](https://github.com/huchenlei/ComfyUI_omost/assets/20929282/191a5ea1-776a-42da-89ee-fd17a3a08eae)

### Accelerating LLM

You can now use [TGI](https://huggingface.co/docs/text-generation-inference) to deploy the LLM as a service and get up to 6x faster inference. If you plan to run many generations, this method is highly recommended and will save you a lot of time.

**Preparation**: You will need an additional 20GB of VRAM to deploy an 8B LLM (trading space for time).

**First**, start the service with Docker:
```
port=8080
modelID=lllyasviel/omost-llama-3-8b
memoryRate=0.9 # Normal operation requires about 20GB of VRAM; adjust this fraction to the VRAM of the deployment machine
volume=$HOME/.cache/huggingface/hub # Model cache directory

docker run --gpus all -p $port:80 \
-v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.0.4 \
--model-id $modelID --max-total-tokens 9216 --cuda-memory-fraction $memoryRate
```
Once the service has started successfully, you will see a `Connected` log message.

(Note: If you get stuck while downloading the model, try using a network proxy.)

**Then**, test whether the LLM service has started successfully.
```
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Omost?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
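
If you prefer to test from Python, here is a minimal sketch that mirrors the curl call above (it assumes the server is reachable at `http://127.0.0.1:8080`, i.e. the port used in the Docker step):
```
import requests

# Assumed TGI address from the Docker step above.
base = "http://127.0.0.1:8080"

# Mirror of the curl test: TGI's /generate endpoint returns {"generated_text": ...}.
resp = requests.post(
    f"{base}/generate",
    json={"inputs": "What is Deep Omost?", "parameters": {"max_new_tokens": 20}},
    timeout=60,
)
print(resp.json()["generated_text"])
```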

**Next**, add an `Omost LLM HTTP Server` node and enter the service address of the LLM (e.g. `http://127.0.0.1:8080` for the Docker setup above).
![image](https://github.com/huchenlei/ComfyUI_omost/assets/6883957/8cf1f3a8-f4d7-416c-a1d0-be27bc300c96)
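
Under the hood, the node appends `/v1` to the address you enter, reads the deployed `model_id` from TGI's `/info` endpoint, and then sends the conversation through the OpenAI-compatible chat-completions API. A minimal sketch of that flow (the system and user prompts below are placeholders, not the actual Omost prompts):
```
import requests
from openai import OpenAI

address = "http://127.0.0.1:8080"  # the same address you enter in the node

# The node reads the model id from TGI's /info endpoint.
model_id = requests.get(f"{address}/info", timeout=5).json()["model_id"]

# TGI exposes an OpenAI-compatible API under /v1; the API key is not checked.
client = OpenAI(base_url=f"{address}/v1", api_key="_")

completion = client.chat.completions.create(
    model=model_id,
    messages=[
        {"role": "system", "content": "<Omost system prompt>"},  # placeholder
        {"role": "user", "content": "generate an image of a cat on a sofa"},
    ],
    max_tokens=4096,
    top_p=0.9,
    temperature=0.6,
)
print(completion.choices[0].message.content)
```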


For more information about TGI, refer to the official documentation: https://huggingface.co/docs/text-generation-inference/quicktour
163 changes: 114 additions & 49 deletions omost_nodes.py
Expand Up @@ -3,9 +3,12 @@
import json
from typing import Literal, Tuple, TypedDict, NamedTuple
import sys
import os
import logging
from typing_extensions import NotRequired

import requests
from openai import OpenAI
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -55,6 +58,12 @@ class OmostLLM(NamedTuple):
tokenizer: AutoTokenizer


class OmostLLMServer(NamedTuple):
client: OpenAI
model_id: str
tokenizer: AutoTokenizer


ComfyUIConditioning = list # Dummy type definitions for ComfyUI
ComfyCLIPTokensWithWeight = list[Tuple[int, float]]

@@ -101,12 +110,47 @@ def load_llm(self, llm_name: str) -> Tuple[OmostLLM]:
torch_dtype=dtype, # This is computation type, not load/memory type. The loading quant type is baked in config.
token=HF_TOKEN,
device_map="auto", # This will load model to gpu with an offload system
trust_remote_code=True,
)
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, token=HF_TOKEN)

return (OmostLLM(llm_model, llm_tokenizer),)


class OmostLLMHTTPServerNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"address": ("STRING", {"multiline": True}),
}
}

RETURN_TYPES = ("OMOST_LLM",)
FUNCTION = "init_client"
CATEGORY = "omost"

def init_client(self, address: str) -> Tuple[OmostLLMServer]:
"""Initialize LLM client with HTTP server address."""
if address.endswith("v1"):
server_address = address
server_info_url = address.replace("v1", "info")
else:
server_address = os.path.join(address, "v1")
server_info_url = os.path.join(address, "info")

client = OpenAI(base_url=server_address, api_key="_")

# Get model_id from server info
server_info = requests.get(server_info_url, timeout=5).json()
model_id = server_info["model_id"]

# Load tokenizer
llm_tokenizer = AutoTokenizer.from_pretrained(model_id)

return (OmostLLMServer(client, model_id, llm_tokenizer), )


class OmostLLMChatNode:
@classmethod
def INPUT_TYPES(s):
@@ -138,75 +182,94 @@ def INPUT_TYPES(s):
"OMOST_CONVERSATION",
"OMOST_CANVAS_CONDITIONING",
)
FUNCTION = "run_llm_with_seed"
FUNCTION = "run_llm"
CATEGORY = "omost"

def run_llm_with_seed(
def prepare_conversation(
self,
llm: OmostLLM,
text: str,
conversation: OmostConversation | None = None
) -> Tuple[OmostConversation, OmostConversation, OmostConversationItem]:
conversation = conversation or [] # Default to empty list
system_conversation_item: OmostConversationItem = {
"role": "system",
"content": system_prompt,
}
user_conversation_item: OmostConversationItem = {
"role": "user",
"content": text,
}
input_conversation: list[OmostConversationItem] = [
system_conversation_item,
*conversation,
user_conversation_item,
]
return conversation, input_conversation, user_conversation_item

def run_local_llm(
self,
llm: OmostLLM,
input_conversation: list[OmostConversationItem],
max_new_tokens: int,
top_p: float,
temperature: float,
seed: int,
conversation: OmostConversation | None = None,
) -> Tuple[OmostConversation, OmostCanvas]:
if seed > 0xFFFFFFFF:
seed = seed & 0xFFFFFFFF
logger.warning("Seed is too large. Truncating to 32-bit: %d", seed)

) -> str:
with scoped_torch_random(seed), scoped_numpy_random(seed):
return self.run_llm(
llm, text, max_new_tokens, top_p, temperature, conversation
llm_tokenizer: AutoTokenizer = llm.tokenizer
llm_model: AutoModelForCausalLM = llm.model

input_ids: torch.Tensor = llm_tokenizer.apply_chat_template(
input_conversation, return_tensors="pt", add_generation_prompt=True
).to(llm_model.device)
input_length = input_ids.shape[1]

output_ids: torch.Tensor = llm_model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature != 0,
)
generated_ids = output_ids[:, input_length:]
generated_text: str = llm_tokenizer.decode(
generated_ids[0],
skip_special_tokens=True,
skip_prompt=True,
timeout=10,
)
return generated_text

def run_llm(
self,
llm: OmostLLM,
llm: OmostLLM | OmostLLMServer,
text: str,
max_new_tokens: int,
top_p: float,
top_p: float,
temperature: float,
seed: int,
conversation: OmostConversation | None = None,
) -> Tuple[OmostConversation, OmostCanvas]:
"""Run LLM on text"""
llm_tokenizer: AutoTokenizer = llm.tokenizer
llm_model: AutoModelForCausalLM = llm.model

conversation = conversation or [] # Default to empty list
system_conversation_item: OmostConversationItem = {
"role": "system",
"content": system_prompt,
}
user_conversation_item: OmostConversationItem = {
"role": "user",
"content": text,
}
input_conversation: list[OmostConversationItem] = [
system_conversation_item,
*conversation,
user_conversation_item,
]
if seed > 0xFFFFFFFF:
seed = seed & 0xFFFFFFFF
logger.warning("Seed is too large. Truncating to 32-bit: %d", seed)

conversation, input_conversation, user_conversation_item = self.prepare_conversation(text, conversation)

input_ids: torch.Tensor = llm_tokenizer.apply_chat_template(
input_conversation, return_tensors="pt", add_generation_prompt=True
).to(llm_model.device)
input_length = input_ids.shape[1]

output_ids: torch.Tensor = llm_model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature != 0,
)
generated_ids = output_ids[:, input_length:]
generated_text: str = llm_tokenizer.decode(
generated_ids[0],
skip_special_tokens=True,
skip_prompt=True,
timeout=10,
)
if isinstance(llm, OmostLLM):
generated_text = self.run_local_llm(
llm, input_conversation, max_new_tokens, top_p, temperature, seed
)
else:
generated_text = llm.client.chat.completions.create(
model=llm.model_id,
messages=input_conversation,
top_p=top_p,
temperature=temperature,
max_tokens=max_new_tokens,
seed=seed,
).choices[0].message.content

output_conversation = [
*conversation,
@@ -496,6 +559,7 @@ def load_canvas(self, json_str: str) -> Tuple[list[OmostCanvasCondition]]:

NODE_CLASS_MAPPINGS = {
"OmostLLMLoaderNode": OmostLLMLoaderNode,
"OmostLLMHTTPServerNode": OmostLLMHTTPServerNode,
"OmostLLMChatNode": OmostLLMChatNode,
"OmostLayoutCondNode": OmostLayoutCondNode,
"OmostLoadCanvasConditioningNode": OmostLoadCanvasConditioningNode,
@@ -504,6 +568,7 @@ def load_canvas(self, json_str: str) -> Tuple[list[OmostCanvasCondition]]:

NODE_DISPLAY_NAME_MAPPINGS = {
"OmostLLMLoaderNode": "Omost LLM Loader",
"OmostLLMHTTPServerNode": "Omost LLM HTTP Server",
"OmostLLMChatNode": "Omost LLM Chat",
"OmostLayoutCondNode": "Omost Layout Cond (ComfyUI-Area)",
"OmostLoadCanvasConditioningNode": "Omost Load Canvas Conditioning",
2 changes: 2 additions & 0 deletions requirements.txt
@@ -2,3 +2,5 @@ transformers>=4.41.1
bitsandbytes>=0.43.1
protobuf>=3.20
torch
openai
requests