
Model integration for Lit-GPT #1792

Merged: 45 commits merged into main on Sep 26, 2023.

Changes shown below are from 3 of the 45 commits.

Commits (all by aniketmaurya):
217e938  lit-gpt integration (Aug 15, 2023)
7bdceb5  update (Aug 15, 2023)
648248c  update (Aug 15, 2023)
a099f0c  remove top_k (Sep 4, 2023)
6e3c968  formatting (Sep 4, 2023)
529a4e7  implement logprob in a future PR (Sep 4, 2023)
4ece813  implement logprob in a future PR (Sep 4, 2023)
e5e73b5  LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG (Sep 4, 2023)
a7c583b  update logs (Sep 4, 2023)
716f6e4  allow temperature zero (Sep 4, 2023)
f6841ec  Merge branch 'main' into lit-gpt (Sep 9, 2023)
f35af30  fixes (Sep 9, 2023)
608c0fc  eos ids (Sep 9, 2023)
6a99819  format (Sep 9, 2023)
f30b69f  fix devices (Sep 9, 2023)
800e624  fix (Sep 9, 2023)
9aa8568  add stop words (Sep 10, 2023)
d7ef12f  fixes (Sep 10, 2023)
0b29328  singleton (Sep 14, 2023)
db95353  fix (Sep 14, 2023)
0df756e  merge main (Sep 20, 2023)
4b42ab0  formatting (Sep 20, 2023)
f300dd1  formatting (Sep 20, 2023)
c3fadb8  fix import (Sep 20, 2023)
aed82e8  apply suggestion (Sep 24, 2023)
4705abc  apply suggestion (Sep 24, 2023)
8b88d67  apply suggestion (Sep 24, 2023)
373bd8b  update lit-gpt (Sep 24, 2023)
966b5c6  update lit-gpt (Sep 24, 2023)
3cdc206  apply suggestions (Sep 26, 2023)
7a0756b  update (Sep 26, 2023)
c7da5d8  update (Sep 26, 2023)
28b46d7  update (Sep 26, 2023)
46ae293  format (Sep 26, 2023)
1a11550  format (Sep 26, 2023)
3e1ef09  black formatting (Sep 26, 2023)
d438939  black formatting (Sep 26, 2023)
8fa6516  fix ci (Sep 26, 2023)
7b3311e  Merge branch 'main' into lit-gpt (Sep 26, 2023)
ef080f2  fix typing (Sep 26, 2023)
d4d60a3  fix typing (Sep 26, 2023)
6864507  format (Sep 26, 2023)
0a6d561  fix typing (Sep 26, 2023)
399176e  fix (Sep 26, 2023)
678e584  fix type (Sep 26, 2023)
10 changes: 10 additions & 0 deletions src/helm/benchmark/static/schema.yaml
@@ -346,6 +346,16 @@ models:
release_date: 2023-01-23
todo: true

# Lightning AI's Lit-GPT
- name: lightningai/lit-gpt
display_name: Lit-GPT
description: Lit-GPT is an optimized collection of open-source LLMs for finetuning and inference. It supports Falcon, Llama 2, Vicuna, LongChat, and other top-performing open-source large language models.
creator_organization: Lightning AI
access: open
num_parameters: 1
Comment (Contributor, author): depends on the chosen model from the config

release_date: 2023-04-04


# Meta
- name: together/opt-iml-175b
display_name: OPT-IML (175B)
27 changes: 27 additions & 0 deletions src/helm/benchmark/window_services/lit_gpt_window_service.py
@@ -0,0 +1,27 @@
from .local_window_service import LocalWindowService
from .tokenizer_service import TokenizerService


class LitGPTWindowServce(LocalWindowService):
def __init__(self, service: TokenizerService):
super().__init__(service)

@property
def max_sequence_length(self) -> int:
return 2048

@property
def max_request_length(self) -> int:
return self.max_sequence_length

@property
def end_of_text_token(self) -> str:
return "<|endoftext|>"

@property
def tokenizer_name(self) -> str:
return "lightningai/lit-gpt"

@property
def prefix_token(self) -> str:
return self.end_of_text_token
4 changes: 4 additions & 0 deletions src/helm/benchmark/window_services/window_service_factory.py
@@ -50,6 +50,7 @@
from .yalm_window_service import YaLMWindowService
from .llama_window_service import LlamaWindowService, Llama2WindowService
from .window_service import WindowService
from .lit_gpt_window_service import LitGPTWindowServce
from .tokenizer_service import TokenizerService
from helm.proxy.clients.huggingface_client import get_huggingface_model_config
from helm.proxy.clients.remote_model_registry import get_remote_model
@@ -222,6 +223,9 @@ def get_window_service(model_name: str, service: TokenizerService) -> WindowService
)
else:
window_service = AI21WindowService(service=service, gpt2_window_service=GPT2WindowService(service))

elif organization == "lightningai":
window_service = LitGPTWindowServce(service)
else:
raise ValueError(f"Unhandled model name: {model_name}")

14 changes: 14 additions & 0 deletions src/helm/proxy/clients/auto_client.py
@@ -175,6 +175,16 @@ def _get_client(self, model: str) -> Client:
from helm.proxy.clients.megatron_client import MegatronClient

client = MegatronClient(cache_config=cache_config)

elif organization == "lightningai":
from helm.proxy.clients.lit_gpt_client import LitGPTClient

client = LitGPTClient(
cache_config=cache_config,
checkpoint_dir=os.environ.get("LIT_GPT_CHECKPOINT_DIR", ""),
precision=os.environ.get("LIT_GPT_PRECISION", "bf16-true")
)

else:
raise ValueError(f"Could not find client for model: {model}")
self.clients[model] = client
@@ -271,6 +281,10 @@ def _get_tokenizer_client(self, tokenizer: str) -> Client:
from helm.proxy.clients.megatron_client import MegatronClient

client = MegatronClient(cache_config=cache_config)

elif organization == "lightningai":
client = self._get_client(tokenizer)

else:
raise ValueError(f"Could not find tokenizer client for model: {tokenizer}")
self.tokenizer_clients[tokenizer] = client
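
The new lightningai branch in _get_client reads its configuration from two environment variables, LIT_GPT_CHECKPOINT_DIR and LIT_GPT_PRECISION. A minimal sketch of setting them before the client is constructed (the checkpoint path is a placeholder, not something defined in this PR):

    import os

    os.environ["LIT_GPT_CHECKPOINT_DIR"] = "/path/to/converted/lit-gpt/checkpoint"  # placeholder path
    os.environ["LIT_GPT_PRECISION"] = "bf16-true"  # matches the default used above
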
149 changes: 149 additions & 0 deletions src/helm/proxy/clients/lit_gpt_client.py
@@ -0,0 +1,149 @@
import json
import time
import logging
from pathlib import Path
from typing import Dict, List, Literal, Optional

import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy
from lit_gpt import GPT, Config, Tokenizer
from lit_gpt.model import Block
from lit_gpt.utils import check_valid_checkpoint_dir, lazy_load, quantization
Comment (Collaborator):
Two things needed here:

  1. Make lightning and lit_gpt optional imports by following the example here.
  2. Add the pip install dependencies to setup.cfg as optional extra dependencies. Follow the OpenAI example: create a new group called lit-gpt and add your dependencies under that group, and then add crfm-helm[lit-gpt] under models =.

Reply (Contributor, author):
I have added lit-gpt as a requirement. Since lit-gpt doesn't come as a pip package, I have added it as lit-gpt @ git+https://github.com/Lightning-AI/lit-gpt@main. It might not be added to PyPI in this form, though.
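
For reference, a minimal sketch of what that could look like in setup.cfg, based only on the group name and git dependency mentioned in this thread (version pins and the exact section layout are assumptions):

    [options.extras_require]
    lit-gpt =
        lightning
        lit-gpt @ git+https://github.com/Lightning-AI/lit-gpt@main

    models =
        crfm-helm[lit-gpt]

And a generic guard for the optional imports, shown as an illustration of the pattern rather than HELM's exact helper:

    try:
        import lightning as L
        from lit_gpt import GPT, Config, Tokenizer
    except ModuleNotFoundError as e:
        # Surface a pip hint instead of a bare import error when the extra is missing.
        raise ModuleNotFoundError("Run `pip install 'crfm-helm[lit-gpt]'` to use the Lit-GPT client.") from e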


from helm.common.cache import Cache, CacheConfig
from helm.common.request import Request, RequestResult, Sequence, Token
from helm.common.tokenization_request import (DecodeRequest,
DecodeRequestResult,
TokenizationRequest,
TokenizationRequestResult,
TokenizationToken)

from .client import Client, wrap_request_time
from .lit_gpt_generate import generate

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


torch.set_float32_matmul_precision("high")
Comment (Collaborator):
Better to do this in LitGPTClient.__init__() so that we don't have side effects from importing the module.
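
A sketch of that change, with the call moved into the constructor so importing the module stays side-effect free (illustrative only; other constructor details are unchanged):

    class LitGPTClient(Client):
        def __init__(self, cache_config: CacheConfig, checkpoint_dir: str = "", precision: str = "bf16-true"):
            # Configure TF32 matmul precision only when a client is actually constructed.
            torch.set_float32_matmul_precision("high")
            self.cache = Cache(cache_config)
            # ... rest of the constructor as below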



class LitGPTClient(Client):
"""Implements some "models" that just generate silly things quickly just to debug the infrastructure."""
Comment (Collaborator):
Make this docstring more useful for users.

Reply (Contributor, author):
Updated the docstring.


def __init__(
self,
cache_config: CacheConfig,
checkpoint_dir: str = "",
precision: str = "bf16-true",
device="auto",
Comment (Collaborator):
whitespace nit: device = "auto"

devices: int = 1,
strategy: str = "auto",
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
Comment (Collaborator):
Define a type name at the module's top level, e.g.

    Quantization = Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]

The type name can be Quantization or QuantizationSetting or something similar.
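
Following that suggestion, the parameter below would then read (this exact line is not part of the diff):

    quantize: Optional[Quantization] = None,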

):
self.cache = Cache(cache_config)
if strategy == "fsdp":
strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
fabric = L.Fabric(devices=devices, accelerator=device, precision=precision, strategy=strategy)
fabric.launch()
logger.info("Using device: {}".format(fabric.device))

checkpoint_dir = Path(checkpoint_dir)
check_valid_checkpoint_dir(checkpoint_dir)

with open(checkpoint_dir / "lit_config.json") as fp:
config = Config(**json.load(fp))

checkpoint_path = checkpoint_dir / "lit_model.pth"
logger.info(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=True), quantization(quantize):
model = GPT(config)

with lazy_load(checkpoint_path) as checkpoint:
model.load_state_dict(checkpoint, strict=quantize is None)

model.eval()
self.model = fabric.setup(model)
self.tokenizer = Tokenizer(checkpoint_dir)
self.fabric = fabric

def make_request(self, request: Request) -> RequestResult:
model = self.model
tokenizer = self.tokenizer
fabric = self.fabric
encoded = tokenizer.encode(
request.prompt, bos=True, eos=False, device=fabric.device
)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + request.max_tokens
assert max_returned_tokens <= model.config.block_size, (
max_returned_tokens,
model.config.block_size,
) # maximum rope cache length

model.reset_cache()
t0 = time.perf_counter()
tokens, logprobs, top_logprobs = generate(
model,
encoded,
max_returned_tokens,
max_seq_length=max_returned_tokens,
temperature=max(request.temperature, 1e-9),
top_k=request.top_p,
)

t = time.perf_counter() - t0
model.reset_cache()
output = tokenizer.decode(tokens)
Comment (Collaborator):
nit: move this to right before completions = [Sequence(text=output, logprob=logprobs_sum, tokens=tokens)].
Also consider renaming output to output_text.

tokens_generated = tokens.size(0) - prompt_length
logger.info(
f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec"
)

logger.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
Comment (Collaborator):
Should these be debug level, or removed entirely?

generated_tokens = []
for t, lp, tlp in zip(tokens, logprobs, top_logprobs):
idx, val = tlp
tok_str = tokenizer.processor.decode([idx])
token_tlp = {tok_str: val}
generated_tokens.append(
Token(text=tokenizer.decode(t), logprob=lp, top_logprobs=token_tlp)
)

logprobs_sum = sum(logprobs)
# Process the input data here
# response = dict(
# text=output, tokens=generated_tokens, logprob=logprobs_sum, request_time=t
# )

#
Comment (Collaborator):
Remove stray comment.

tokens = generated_tokens
Comment (Collaborator):
This line doesn't seem needed.

completions = [Sequence(text=output, logprob=logprobs_sum, tokens=tokens)]

return RequestResult(
success=True,
cached=False,
Comment (Collaborator):
Add a # TODO: implement request caching after model names are configurable here.

error=None,
completions=completions,
embedding=[],
request_time=None,
Comment (Collaborator):
Set request_time.
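
One way to satisfy this with the timing already measured above; note that t is reused as the loop variable over tokens, so the duration needs its own name (elapsed here is a hypothetical name, not in the diff):

    elapsed = time.perf_counter() - t0  # measured right after generate() returns
    # ... token post-processing unchanged ...
    return RequestResult(
        success=True,
        cached=False,
        error=None,
        completions=completions,
        embedding=[],
        request_time=elapsed,
    )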

)

def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
fabric = self.fabric
logger.info("Using device: {}".format(fabric.device))
Comment (Collaborator):
Remove? We only need to know this during fabric initialization.

t0 = time.perf_counter()
encoded = self.tokenizer.encode(
request.text, bos=True, eos=False, device=fabric.device
)
t = time.perf_counter() - t0
tokens = encoded.tolist()
return TokenizationRequestResult(
success=True, cached=False, tokens=tokens, text=request.text,
Comment (Collaborator):
nit: add a # TODO: implement request caching after allowing configurable model names here.
Also, please fix the indentation.

request_time=t
)


def decode(self, request: DecodeRequest) -> DecodeRequestResult:
raise NotImplementedError
79 changes: 79 additions & 0 deletions src/helm/proxy/clients/lit_gpt_generate.py
@@ -0,0 +1,79 @@
from typing import Optional, Tuple, List

import torch


@torch.no_grad()
def generate(
model: torch.nn.Module,
idx: torch.Tensor,
max_returned_tokens: int,
max_seq_length: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
) -> Tuple[List[int], List[float], List[Tuple[int, float]]]:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

The implementation of this function is modified from A. Karpathy's nanoGPT.

Args:
model: The model to use.
idx: Tensor of shape (T) with indices of the prompt sequence.
max_returned_tokens: The maximum number of tokens to return (given plus generated).
max_seq_length: The maximum sequence length allowed. Should be less or equal than the block size.
temperature: Scales the predicted logits by 1 / temperature.
top_k: If specified, only sample among the tokens with the k highest probabilities.
eos_id: If specified, stop generating any more token once the <eos> token is triggered.

Returns:
Tuple containing the list of generated token indices, the log probability of each selected token, and, for each step, the id and log probability of the most likely token.
"""
T = idx.size(0)
assert max_returned_tokens > T
device, dtype = idx.device, idx.dtype
# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
empty[:T] = idx
idx = empty
input_pos = torch.arange(0, T, device=device)

top_logprob = []
logprob = []

# generate up to a fixed number of tokens
for _ in range(max_returned_tokens - T):
x = idx.index_select(0, input_pos).view(1, -1)

# forward
logits = model(x, max_seq_length, input_pos)
logits = logits[0, -1] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

probs = torch.nn.functional.softmax(logits, dim=-1)

idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

# append the logprob of selected token
logprob.append(torch.log(probs[idx_next]).item())

# append the idx and logprob of the top token
top_logprob.append((torch.argmax(probs).item(), torch.log(probs).max().item()))

# advance
input_pos = input_pos[-1:] + 1

# concatenate the new generation
idx = idx.index_copy(0, input_pos, idx_next)

# if <eos> token is triggered, return the output (stop generation)
if idx_next == eos_id:
Comment (Collaborator):
We might need to handle stop sequences that are longer than one token, so we need to check a subsequence... maybe add a TODO comment here for now.
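
A rough sketch of the kind of subsequence check that TODO would cover, assuming the stop sequence arrives as a list of token ids (hit_stop_sequence and stop_ids are hypothetical names, not part of this PR):

    from typing import List

    def hit_stop_sequence(generated_ids: List[int], stop_ids: List[int]) -> bool:
        # True once the tail of the generated ids matches the full stop sequence.
        return len(stop_ids) > 0 and generated_ids[-len(stop_ids):] == stop_ids

The check would run after each newly sampled token is appended, and the matched stop tokens would be trimmed from the returned sequence.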

return idx[:input_pos], logprob, top_logprob # include the EOS token

return idx, logprob, top_logprob
7 changes: 7 additions & 0 deletions src/helm/proxy/models.py
@@ -833,6 +833,13 @@ def engine(self) -> str:
name="stabilityai/stablelm-base-alpha-7b",
tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG],
),

Model(
group="lightningai",
name="lightningai/lit-gpt",
tags=[TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, GPT2_TOKENIZER_TAG]
Comment (Collaborator):
Remove GPT2_TOKENIZER_TAG and INSTRUCTION_FOLLOWING_MODEL_TAG (I assume that not all models use the GPT-2 tokenizer and are instruction following).
Also replace FULL_FUNCTIONALITY_TEXT_MODEL_TAG with LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG for now. We need logprobs, num_completions and top_k_per_token support for full functionality.
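
Based on this suggestion (and the earlier commit e5e73b5, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG), the entry would presumably end up close to:

    Model(
        group="lightningai",
        name="lightningai/lit-gpt",
        tags=[TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG],
    ),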

),

# For debugging
Model(
group="simple",