Model integration for Lit-GPT #1792
Changes from 3 commits
@@ -0,0 +1,27 @@ (new file: Lit-GPT window service)

from .local_window_service import LocalWindowService
from .tokenizer_service import TokenizerService


class LitGPTWindowService(LocalWindowService):
    def __init__(self, service: TokenizerService):
        super().__init__(service)

    @property
    def max_sequence_length(self) -> int:
        return 2048

    @property
    def max_request_length(self) -> int:
        return self.max_sequence_length

    @property
    def end_of_text_token(self) -> str:
        return "<|endoftext|>"

    @property
    def tokenizer_name(self) -> str:
        return "lightningai/lit-gpt"

    @property
    def prefix_token(self) -> str:
        return self.end_of_text_token
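For context, HELM picks a window service per model at run time, and the class above would be the one used for `lightningai/lit-gpt`. A minimal sketch of that dispatch, assuming HELM's usual factory pattern and module paths (the `get_window_service` helper below is hypothetical and for illustration only; the real factory change is not shown in this excerpt, and `LitGPTWindowService` is the class defined above):

```python
from helm.benchmark.window_services.local_window_service import LocalWindowService
from helm.benchmark.window_services.tokenizer_service import TokenizerService


def get_window_service(model_name: str, service: TokenizerService) -> LocalWindowService:
    # Illustrative dispatch only; HELM's real window service factory covers many models.
    if model_name == "lightningai/lit-gpt":
        return LitGPTWindowService(service)
    raise ValueError(f"No window service registered for {model_name}")
```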
@@ -0,0 +1,149 @@ (new file: Lit-GPT client)
import json
import time
import logging
from pathlib import Path
from typing import Dict, List, Literal, Optional

import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy
from lit_gpt import GPT, Config, Tokenizer
from lit_gpt.model import Block
from lit_gpt.utils import check_valid_checkpoint_dir, lazy_load, quantization
Review comment: Two things needed here: […]

Reply: I have added lit-gpt as a requirement. Since lit-gpt doesn't come as a pip package, I have added it as […]
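Since lit-gpt is not published on PyPI, one common approach is to install it directly from its Git repository, for example (repository URL shown for illustration; the exact pin used by this PR is not visible here):

```
pip install git+https://github.com/Lightning-AI/lit-gpt
```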
from helm.common.cache import Cache, CacheConfig
from helm.common.request import Request, RequestResult, Sequence, Token
from helm.common.tokenization_request import (
    DecodeRequest,
    DecodeRequestResult,
    TokenizationRequest,
    TokenizationRequestResult,
    TokenizationToken,
)

from .client import Client, wrap_request_time
from .lit_gpt_generate import generate


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


torch.set_float32_matmul_precision("high")

Review comment: Better to do this in […]

class LitGPTClient(Client):
    """Serves requests using a locally loaded Lit-GPT model (https://github.com/Lightning-AI/lit-gpt)."""

Review comment: Make this docstring more useful for users.

Reply: Updated the docstring.
    def __init__(
        self,
        cache_config: CacheConfig,
        checkpoint_dir: str = "",
        precision: str = "bf16-true",
        device: str = "auto",

Review comment: whitespace nit: […]

        devices: int = 1,
        strategy: str = "auto",
        quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,

Review comment: Define a type name at the module's top level, e.g. […] The type name can be […] (A sketch of this suggestion appears after the constructor below.)

    ):
        self.cache = Cache(cache_config)
        if strategy == "fsdp":
            strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
        fabric = L.Fabric(devices=devices, accelerator=device, precision=precision, strategy=strategy)
        fabric.launch()
        logger.info("Using device: {}".format(fabric.device))

        checkpoint_dir = Path(checkpoint_dir)
        check_valid_checkpoint_dir(checkpoint_dir)

        with open(checkpoint_dir / "lit_config.json") as fp:
            config = Config(**json.load(fp))

        checkpoint_path = checkpoint_dir / "lit_model.pth"
        logger.info(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
        with fabric.init_module(empty_init=True), quantization(quantize):
            model = GPT(config)

        with lazy_load(checkpoint_path) as checkpoint:
            model.load_state_dict(checkpoint, strict=quantize is None)

        model.eval()
        self.model = fabric.setup(model)

        self.tokenizer = Tokenizer(checkpoint_dir)
        self.fabric = fabric
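A minimal sketch of the reviewer's suggestion above: hoist the long quantization `Literal` into a module-level type alias so the constructor signature stays readable. The alias name `Quantize` and the stripped-down class are assumptions for illustration; the thread does not show the final name.

```python
from typing import Literal, Optional

# Assumed alias name; the literal values are the ones accepted by the constructor above.
Quantize = Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]


class ClientSketch:
    def __init__(self, quantize: Optional[Quantize] = None) -> None:
        # Only the annotation changes; the real constructor body stays as above.
        self.quantize = quantize
```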
    def make_request(self, request: Request) -> RequestResult:
        model = self.model
        tokenizer = self.tokenizer
        fabric = self.fabric
        encoded = tokenizer.encode(
            request.prompt, bos=True, eos=False, device=fabric.device
        )
        prompt_length = encoded.size(0)
        max_returned_tokens = prompt_length + request.max_tokens
        assert max_returned_tokens <= model.config.block_size, (
            max_returned_tokens,
            model.config.block_size,
        )  # maximum rope cache length

        model.reset_cache()
        t0 = time.perf_counter()
        tokens, logprobs, top_logprobs = generate(
            model,
            encoded,
            max_returned_tokens,
            max_seq_length=max_returned_tokens,
            temperature=max(request.temperature, 1e-9),
            top_k=request.top_p,
        )

        t = time.perf_counter() - t0
        model.reset_cache()
        output = tokenizer.decode(tokens)

Review comment: nit: move this to right before […] Also consider renaming […]

        tokens_generated = tokens.size(0) - prompt_length
        logger.info(
            f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec"
        )

        logger.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Review comment: Should these be debug level, or removed entirely?

        generated_tokens = []
        for t, lp, tlp in zip(tokens, logprobs, top_logprobs):
            idx, val = tlp
            tok_str = tokenizer.processor.decode([idx])
            token_tlp = {tok_str: val}
            generated_tokens.append(
                Token(text=tokenizer.decode(t), logprob=lp, top_logprobs=token_tlp)
            )

        logprobs_sum = sum(logprobs)
        # Process the input data here
        # response = dict(
        #     text=output, tokens=generated_tokens, logprob=logprobs_sum, request_time=t
        # )

        #

Review comment: Remove stray comment.

        tokens = generated_tokens

Review comment: This line doesn't seem needed.

        completions = [Sequence(text=output, logprob=logprobs_sum, tokens=tokens)]

        return RequestResult(
            success=True,
            cached=False,
            error=None,
            completions=completions,
            embedding=[],
            request_time=None,
        )

Review comment (on cached=False): Add a […]

Review comment (on request_time=None): Set […]
    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
        fabric = self.fabric
        logger.info("Using device: {}".format(fabric.device))

Review comment: Remove? We only need to know this during fabric initialization.

        t0 = time.perf_counter()
        encoded = self.tokenizer.encode(
            request.text, bos=True, eos=False, device=fabric.device
        )
        t = time.perf_counter() - t0
        tokens = encoded.tolist()
        return TokenizationRequestResult(
            success=True, cached=False, tokens=tokens, text=request.text,
            request_time=t
        )

Review comment (on the return statement): nit: add a […] Also, please fix the indentation.

    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
        raise NotImplementedError
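To see how the pieces above fit together, here is a hedged usage sketch that instantiates the client directly and issues a single request. `SqliteCacheConfig` and the `Request` fields are assumed from their use elsewhere in HELM, and the checkpoint path is a placeholder for a local lit-gpt checkpoint directory (which contains the `lit_config.json` and `lit_model.pth` loaded in `__init__` above):

```python
from helm.common.cache import SqliteCacheConfig
from helm.common.request import Request

# Placeholder paths and values for illustration only.
client = LitGPTClient(
    cache_config=SqliteCacheConfig("lit_gpt_cache.sqlite"),
    checkpoint_dir="checkpoints/my-lit-gpt-model",
    precision="bf16-true",
)
result = client.make_request(
    Request(model="lightningai/lit-gpt", prompt="Hello, world", max_tokens=16)
)
print(result.completions[0].text)
```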
@@ -0,0 +1,79 @@ (new file: lit_gpt_generate.py)
from typing import Optional, Tuple, List

import torch


@torch.no_grad()
def generate(
    model: torch.nn.Module,
    idx: torch.Tensor,
    max_returned_tokens: int,
    max_seq_length: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> Tuple[List[int], List[float], List[Tuple[int, float]]]:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        max_seq_length: The maximum sequence length allowed. Should be less than or equal to the block size.
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more tokens once the <eos> token is triggered.

    Returns:
        Tuple containing the generated token indices, the log probability of each sampled token, and the
        (token id, log probability) of the most likely token at each step.
    """
    T = idx.size(0)
    assert max_returned_tokens > T
    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    input_pos = torch.arange(0, T, device=device)

    top_logprob = []
    logprob = []

    # generate up to a fixed number of tokens
    for _ in range(max_returned_tokens - T):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)

        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # append the logprob of the selected token
        logprob.append(torch.log(probs[idx_next]).item())

        # append the idx and logprob of the top token
        top_logprob.append((torch.argmax(probs).item(), torch.log(probs).max().item()))

        # advance
        input_pos = input_pos[-1:] + 1

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)

        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
Review comment: We might need to handle stop sequences that are longer than one token, so we need to check a subsequence... maybe add a TODO comment here for now.
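A hedged sketch of the subsequence check the reviewer describes: after each step, compare the tail of the tokens generated so far against every tokenized stop sequence. The helper name and the `stop_sequences` argument are assumptions; `generate()` above currently only supports a single `eos_id`.

```python
from typing import List

import torch


def hit_stop_sequence(generated: torch.Tensor, stop_sequences: List[torch.Tensor]) -> bool:
    """Return True if the generated token ids end with any tokenized stop sequence."""
    for stop in stop_sequences:
        n = stop.size(0)
        if generated.size(0) >= n and torch.equal(generated[-n:], stop):
            return True
    return False
```

Inside the loop above, such a check would run on the prefix filled in so far before deciding to return early, instead of comparing only the single newest token against `eos_id`.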
            return idx[:input_pos], logprob, top_logprob  # include the EOS token

    return idx, logprob, top_logprob
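For reference, a quick way to exercise `generate()` in isolation is with a stand-in model. The dummy module below only mimics the call signature `model(x, max_seq_length, input_pos)` used above and returns random logits; it is not the lit_gpt `GPT` class, and the vocabulary size is arbitrary.

```python
import torch


class DummyModel(torch.nn.Module):
    """Stand-in with the same call signature that generate() expects."""

    def __init__(self, vocab_size: int = 100) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, x: torch.Tensor, max_seq_length: int, input_pos: torch.Tensor) -> torch.Tensor:
        # Shape (batch, time, vocab), matching what generate() indexes with [0, -1].
        return torch.randn(x.size(0), x.size(1), self.vocab_size)


prompt = torch.tensor([1, 2, 3], dtype=torch.long)
tokens, logprobs, top_logprobs = generate(
    DummyModel(), prompt, max_returned_tokens=8, max_seq_length=8, top_k=5
)
print(tokens.tolist(), logprobs, top_logprobs)
```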
@@ -833,6 +833,13 @@ def engine(self) -> str:

        name="stabilityai/stablelm-base-alpha-7b",
        tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG],
    ),
    Model(
        group="lightningai",
        name="lightningai/lit-gpt",
        tags=[TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, GPT2_TOKENIZER_TAG],
Review comment: Remove […] Also replace […]
    ),

    # For debugging
    Model(
        group="simple",
Review comment: Depends on the chosen model from the config.