Model integration for Lit-GPT #1792

Merged: 45 commits, merged on Sep 26, 2023.

Commits (changes from all commits):
217e938 lit-gpt integration (aniketmaurya, Aug 15, 2023)
7bdceb5 update (aniketmaurya, Aug 15, 2023)
648248c update (aniketmaurya, Aug 15, 2023)
a099f0c remove top_k (aniketmaurya, Sep 4, 2023)
6e3c968 formatting (aniketmaurya, Sep 4, 2023)
529a4e7 implement logprob in a future PR (aniketmaurya, Sep 4, 2023)
4ece813 implement logprob in a future PR (aniketmaurya, Sep 4, 2023)
e5e73b5 LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG (aniketmaurya, Sep 4, 2023)
a7c583b update logs (aniketmaurya, Sep 4, 2023)
716f6e4 allow temperature zero (aniketmaurya, Sep 4, 2023)
f6841ec Merge branch 'main' into lit-gpt (aniketmaurya, Sep 9, 2023)
f35af30 fixes (aniketmaurya, Sep 9, 2023)
608c0fc eos ids (aniketmaurya, Sep 9, 2023)
6a99819 format (aniketmaurya, Sep 9, 2023)
f30b69f fix devices (aniketmaurya, Sep 9, 2023)
800e624 fix (aniketmaurya, Sep 9, 2023)
9aa8568 add stop words (aniketmaurya, Sep 10, 2023)
d7ef12f fixes (aniketmaurya, Sep 10, 2023)
0b29328 singleton (aniketmaurya, Sep 14, 2023)
db95353 fix (aniketmaurya, Sep 14, 2023)
0df756e merge main (aniketmaurya, Sep 20, 2023)
4b42ab0 formatting (aniketmaurya, Sep 20, 2023)
f300dd1 formatting (aniketmaurya, Sep 20, 2023)
c3fadb8 fix import (aniketmaurya, Sep 20, 2023)
aed82e8 apply suggestion (aniketmaurya, Sep 24, 2023)
4705abc apply suggestion (aniketmaurya, Sep 24, 2023)
8b88d67 apply suggestion (aniketmaurya, Sep 24, 2023)
373bd8b update lit-gpt (aniketmaurya, Sep 24, 2023)
966b5c6 update lit-gpt (aniketmaurya, Sep 24, 2023)
3cdc206 apply suggestions (aniketmaurya, Sep 26, 2023)
7a0756b update (aniketmaurya, Sep 26, 2023)
c7da5d8 update (aniketmaurya, Sep 26, 2023)
28b46d7 update (aniketmaurya, Sep 26, 2023)
46ae293 format (aniketmaurya, Sep 26, 2023)
1a11550 format (aniketmaurya, Sep 26, 2023)
3e1ef09 black formatting (aniketmaurya, Sep 26, 2023)
d438939 black formatting (aniketmaurya, Sep 26, 2023)
8fa6516 fix ci (aniketmaurya, Sep 26, 2023)
7b3311e Merge branch 'main' into lit-gpt (aniketmaurya, Sep 26, 2023)
ef080f2 fix typing (aniketmaurya, Sep 26, 2023)
d4d60a3 fix typing (aniketmaurya, Sep 26, 2023)
6864507 format (aniketmaurya, Sep 26, 2023)
0a6d561 fix typing (aniketmaurya, Sep 26, 2023)
399176e fix (aniketmaurya, Sep 26, 2023)
678e584 fix type (aniketmaurya, Sep 26, 2023)
4 changes: 4 additions & 0 deletions setup.cfg
@@ -104,6 +104,9 @@ openai =
openai~=0.27.8
tiktoken~=0.3.3

lit-gpt =
lit-gpt @ git+https://github.com/Lightning-AI/lit-gpt@main
Collaborator: I think PyPI will reject packages with dependencies from GitHub, but we can deal with this in a later PR.

Reply: True, no dependencies from sources outside PyPI are accepted; you will need to drop them.


tsinghua =
icetk~=0.0.4

@@ -116,6 +119,7 @@ models =
crfm-helm[openai]
crfm-helm[tsinghua]
crfm-helm[yandex]
crfm-helm[lit-gpt]

cleva =
unidecode==1.3.6
10 changes: 10 additions & 0 deletions src/helm/benchmark/static/schema.yaml
@@ -356,6 +356,16 @@ models:
release_date: 2023-01-23
todo: true

# Lightning AI's Lit-GPT
- name: lightningai/lit-gpt
display_name: Lit-GPT
description: Lit-GPT is an optimized collection of open-source LLMs for finetuning and inference. It supports Falcon, Llama 2, Vicuna, LongChat, and other top-performing open-source large language models.
creator_organization: Lightning AI
access: open
num_parameters: 1
Contributor (author): This depends on the model chosen in the config.

release_date: 2023-04-04


# Meta
- name: together/opt-iml-175b
display_name: OPT-IML (175B)
27 changes: 27 additions & 0 deletions src/helm/benchmark/window_services/lit_gpt_window_service.py
@@ -0,0 +1,27 @@
from .local_window_service import LocalWindowService
from .tokenizer_service import TokenizerService


class LitGPTWindowServce(LocalWindowService):
def __init__(self, service: TokenizerService):
super().__init__(service)

@property
def max_sequence_length(self) -> int:
return 2048

@property
def max_request_length(self) -> int:
return self.max_sequence_length

@property
def end_of_text_token(self) -> str:
return "<|endoftext|>"

@property
def tokenizer_name(self) -> str:
return "lightningai/lit-gpt"

@property
def prefix_token(self) -> str:
return self.end_of_text_token
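For orientation, here is a minimal usage sketch (not part of the PR) of the window service above, assuming a TokenizerService instance is available from elsewhere; the helper function name is made up, and the class name keeps the PR's spelling.

from helm.benchmark.window_services.lit_gpt_window_service import LitGPTWindowServce
from helm.benchmark.window_services.tokenizer_service import TokenizerService


def describe_lit_gpt_window(service: TokenizerService) -> None:
    # Inspect the fixed window parameters the new service exposes.
    window_service = LitGPTWindowServce(service)
    assert window_service.max_sequence_length == 2048
    assert window_service.max_request_length == 2048
    assert window_service.end_of_text_token == "<|endoftext|>"
    assert window_service.prefix_token == "<|endoftext|>"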
5 changes: 5 additions & 0 deletions src/helm/benchmark/window_services/window_service_factory.py
@@ -280,6 +280,11 @@ def get_window_service(model_name: str, service: TokenizerService) -> WindowServ
)
else:
window_service = AI21WindowService(service=service, gpt2_window_service=GPT2WindowService(service))

elif organization == "lightningai":
from helm.benchmark.window_services.lit_gpt_window_service import LitGPTWindowServce

window_service = LitGPTWindowServce(service)
else:
raise ValueError(f"Unhandled model name: {model_name}")

37 changes: 28 additions & 9 deletions src/helm/proxy/clients/auto_client.py
@@ -1,27 +1,28 @@
import os
from dataclasses import replace
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional

from retrying import RetryError, Attempt
from retrying import Attempt, RetryError

from helm.benchmark.model_deployment_registry import get_model_deployment
from helm.common.cache import CacheConfig, MongoCacheConfig, SqliteCacheConfig
from helm.common.hierarchical_logger import hlog
from helm.common.object_spec import create_object
from helm.common.request import Request, RequestResult
from helm.common.tokenization_request import (
TokenizationRequest,
TokenizationRequestResult,
DecodeRequest,
DecodeRequestResult,
TokenizationRequest,
TokenizationRequestResult,
)
from helm.proxy.retry import retry_request, NonRetriableException
from helm.proxy.clients.critique_client import CritiqueClient
from helm.proxy.clients.client import Client
from .http_model_client import HTTPModelClient
from helm.proxy.clients.critique_client import CritiqueClient
from helm.proxy.clients.huggingface_model_registry import get_huggingface_model_config
from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient
from helm.proxy.retry import NonRetriableException, retry_request

from .http_model_client import HTTPModelClient

if TYPE_CHECKING:
import helm.proxy.clients.huggingface_client
@@ -174,6 +175,16 @@ def _get_client(self, model: str) -> Client:
from helm.proxy.clients.megatron_client import MegatronClient

client = MegatronClient(cache_config=cache_config)

elif organization == "lightningai":
from helm.proxy.clients.lit_gpt_client import LitGPTClient

client = LitGPTClient(
cache_config=cache_config,
checkpoint_dir=Path(os.environ.get("LIT_GPT_CHECKPOINT_DIR", "")),
precision=os.environ.get("LIT_GPT_PRECISION", "bf16-true"),
)

else:
raise ValueError(f"Could not find client for model: {model}")
self.clients[model] = client
@@ -273,6 +284,10 @@ def _get_tokenizer_client(self, tokenizer: str) -> Client:
from helm.proxy.clients.megatron_client import MegatronClient

client = MegatronClient(cache_config=cache_config)

elif organization == "lightningai":
client = self._get_client(tokenizer)

else:
raise ValueError(f"Could not find tokenizer client for model: {tokenizer}")
self.tokenizer_clients[tokenizer] = client
@@ -327,11 +342,15 @@ def get_critique_client(self) -> CritiqueClient:

self._critique_client = RandomCritiqueClient()
elif critique_type == "mturk":
from helm.proxy.clients.mechanical_turk_critique_client import MechanicalTurkCritiqueClient
from helm.proxy.clients.mechanical_turk_critique_client import (
MechanicalTurkCritiqueClient,
)

self._critique_client = MechanicalTurkCritiqueClient()
elif critique_type == "surgeai":
from helm.proxy.clients.surge_ai_critique_client import SurgeAICritiqueClient
from helm.proxy.clients.surge_ai_critique_client import (
SurgeAICritiqueClient,
)

surgeai_credentials = self.credentials.get("surgeaiApiKey")
if not surgeai_credentials:
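As a reference, a minimal sketch (not from the diff) of what the new "lightningai" branch in _get_client boils down to: the client is configured through the LIT_GPT_CHECKPOINT_DIR and LIT_GPT_PRECISION environment variables. The checkpoint and cache paths below are only placeholders.

import os
from pathlib import Path

from helm.common.cache import SqliteCacheConfig
from helm.proxy.clients.lit_gpt_client import LitGPTClient

# Placeholder values; point LIT_GPT_CHECKPOINT_DIR at a converted lit-gpt checkpoint directory.
os.environ.setdefault("LIT_GPT_CHECKPOINT_DIR", "checkpoints/tiiuae/falcon-7b")
os.environ.setdefault("LIT_GPT_PRECISION", "bf16-true")

client = LitGPTClient(
    cache_config=SqliteCacheConfig("prod_env/cache/lightningai.sqlite"),
    checkpoint_dir=Path(os.environ["LIT_GPT_CHECKPOINT_DIR"]),
    precision=os.environ["LIT_GPT_PRECISION"],
)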
180 changes: 180 additions & 0 deletions src/helm/proxy/clients/lit_gpt_client.py
@@ -0,0 +1,180 @@
import json
import logging
import time
from pathlib import Path
from threading import Lock
from typing import List, Literal, Optional, Dict, Union

import torch

from helm.common.cache import Cache, CacheConfig
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import Request, RequestResult, Sequence, Token
from helm.common.tokenization_request import (
DecodeRequest,
DecodeRequestResult,
TokenizationRequest,
TokenizationRequestResult,
TokenizationToken,
)

from .client import Client
from .lit_gpt_generate import generate # type: ignore

try:
import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from lit_gpt import GPT, Config, Tokenizer
from lit_gpt.model import Block
from lit_gpt.utils import check_valid_checkpoint_dir, lazy_load, quantization
except ModuleNotFoundError as e:
handle_module_not_found_error(e)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

QuantizationType = Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]


class SingletonMeta(type):
_instances: Dict[type, type] = {}
_lock: Lock = Lock()

def __call__(cls, *args, **kwargs):
with cls._lock:
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
return cls._instances[cls]


class LitGPT(metaclass=SingletonMeta):
def __init__(
self,
checkpoint_dir: Path = Path(""),
precision: str = "bf16-true",
device: str = "auto",
devices: int = 1,
strategy: Union[str, FSDPStrategy] = "auto",
quantize: Optional[QuantizationType] = None,
):
torch.set_float32_matmul_precision("high")

if strategy == "fsdp":
strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
fabric = L.Fabric(devices=devices, accelerator=device, precision=precision, strategy=strategy) # type: ignore
fabric.launch()
logger.info("Using device: {}".format(fabric.device))

check_valid_checkpoint_dir(checkpoint_dir)

with open(checkpoint_dir / "lit_config.json") as fp:
config = Config(**json.load(fp))

checkpoint_path = checkpoint_dir / "lit_model.pth"
logger.info(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=True), quantization(quantize):
model = GPT(config)

with lazy_load(checkpoint_path) as checkpoint:
model.load_state_dict(checkpoint, strict=quantize is None)

model.eval()
self.model = fabric.setup(model)
self.tokenizer = Tokenizer(checkpoint_dir)
self.fabric = fabric


class LitGPTClient(Client):
"""Client for evaluating Lit-GPT (from Lightning AI) supported LLMs"""

def __init__(
self,
cache_config: CacheConfig,
checkpoint_dir: Path = Path(""),
precision: str = "bf16-true",
device: str = "auto",
devices: int = 1,
strategy: str = "auto",
quantize: Optional[QuantizationType] = None,
):
self.cache = Cache(cache_config)
lit_gpt = LitGPT(checkpoint_dir, precision, device, devices, strategy, quantize)
self.model = lit_gpt.model
self.tokenizer = lit_gpt.tokenizer
self.fabric = lit_gpt.fabric

def make_request(self, request: Request) -> RequestResult:
model = self.model
tokenizer = self.tokenizer
fabric = self.fabric
encoded = tokenizer.encode(request.prompt, bos=True, eos=False, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens: int = prompt_length + request.max_tokens
assert max_returned_tokens <= model.config.block_size, (
max_returned_tokens,
model.config.block_size,
) # maximum rope cache length

model.clear_kv_cache()

with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens

t0 = time.perf_counter()
# helm doesn't have anything equivalent to top_k at the moment
# TODO: allow temperature=0, pick the top token rather than sampling.
stop_tokens: List[torch.Tensor] = [tokenizer.encode(e, device=fabric.device) for e in request.stop_sequences]

with fabric.init_tensor():
# enable the kv cache
model.set_kv_cache(batch_size=1)
tokens = generate(
model,
encoded,
max_returned_tokens,
temperature=max(request.temperature, 1e-11),
stop_tokens=stop_tokens,
)

t = time.perf_counter() - t0
model.clear_kv_cache()
if request.echo_prompt is False:
output = tokenizer.decode(tokens[prompt_length:])
else:
output = tokenizer.decode(tokens)
tokens_generated = tokens.size(0) - prompt_length

logger.debug(f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec")
logger.debug(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

generated_tokens = []
for token in tokens:
generated_tokens.append(Token(text=tokenizer.decode(token), logprob=0, top_logprobs={}))
completions = [Sequence(text=output, logprob=0, tokens=generated_tokens)]

return RequestResult(
success=True,
cached=False,
error=None,
completions=completions,
embedding=[],
request_time=t,
)

def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
fabric = self.fabric
logger.debug("Using device: {}".format(fabric.device))
t0 = time.perf_counter()
encoded = self.tokenizer.encode(request.text, bos=True, eos=False, device=fabric.device)
tokens = encoded.tolist()
tokens = [TokenizationToken(value=token) for token in tokens]
t = time.perf_counter() - t0
return TokenizationRequestResult(success=True, cached=False, tokens=tokens, text=request.text, request_time=t)

def decode(self, request: DecodeRequest) -> DecodeRequestResult:
t0 = time.perf_counter()
text = self.tokenizer.decode(torch.as_tensor(request.tokens, dtype=torch.int))
t = time.perf_counter() - t0
return DecodeRequestResult(success=True, cached=False, text=text, request_time=t)
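To round out the picture, a short usage sketch (not part of the PR) of the client above; the prompt, stop sequence, and tokenizer name are illustrative. Note that make_request clamps the temperature to 1e-11 and reports every token logprob as 0 until log-probability support lands in a later PR, and that the same client is reused as its own tokenizer client in auto_client.py.

from helm.common.request import Request
from helm.common.tokenization_request import TokenizationRequest
from helm.proxy.clients.lit_gpt_client import LitGPTClient


def run_example(client: LitGPTClient) -> None:
    # Near-greedy completion: temperature 0 is clamped to 1e-11 inside make_request.
    request = Request(
        model="lightningai/lit-gpt",
        prompt="The capital of France is",
        max_tokens=16,
        temperature=0.0,
        stop_sequences=["\n"],
    )
    result = client.make_request(request)
    print(result.completions[0].text)

    # Tokenization goes through the same client (see the _get_tokenizer_client change above).
    tokenization = client.tokenize(TokenizationRequest(text="Hello HELM", tokenizer="lightningai/lit-gpt"))
    print(len(tokenization.tokens))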