feat: added '__repr__' function for scalellm package (#242)
guocuimi authored Jun 13, 2024
1 parent 4a03210 commit 63a4b45
Showing 9 changed files with 121 additions and 10 deletions.
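
The new __repr__ bindings give the pybind11 objects readable string forms. A minimal sketch of the effect, assuming Message is re-exported from the top-level scalellm package (the output follows the format string added in llm_handler.cpp below):

    from scalellm import Message  # top-level re-export assumed

    msg = Message(role="user", content="hello")
    print(repr(msg))  # e.g. Message(user: 'hello')
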
3 changes: 3 additions & 0 deletions scalellm/_C/llm_handler.pyi
@@ -7,6 +7,7 @@ from scalellm._C.sampling_params import SamplingParams
# Defined in csrc/llm_handler.cpp
class Message:
def __init__(self, role: str, content: str) -> None: ...
def __repr__(self) -> str: ...
role: str
content: str

@@ -27,6 +28,7 @@ class BatchFuture:
class LLMHandler:
class Options:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
model_path: str
devices: Optional[str]
draft_model_path: Optional[str]
@@ -45,6 +47,7 @@ class LLMHandler:
num_handling_threads: int

def __init__(self, options: Options) -> None: ...
def __repr__(self) -> str: ...
def schedule_async(
self,
prompt: str,
6 changes: 6 additions & 0 deletions scalellm/_C/output.pyi
@@ -4,19 +4,22 @@ from typing import List, Optional
# Defined in csrc/output.cpp
class Usage:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
num_prompt_tokens: int
num_generated_tokens: int
num_total_tokens: int

class LogProbData:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
token: str
token_id: int
logprob: float
finished_token: bool

class LogProb:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
token: str
token_id: int
logprob: float
@@ -25,6 +28,7 @@ class LogProb:

class SequenceOutput:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
index: int
text: str
token_ids: List[int]
@@ -33,6 +37,7 @@ class SequenceOutput:

class RequestOutput:
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
prompt: Optional[str]
status: Optional[Status]
outputs: List[SequenceOutput]
@@ -52,6 +57,7 @@ class StatusCode(Enum):

class Status:
def __init__(self, code: StatusCode, message: str) -> None: ...
def __repr__(self) -> str: ...
@property
def code(self) -> StatusCode: ...
@property
1 change: 1 addition & 0 deletions scalellm/_C/sampling_params.pyi
@@ -21,6 +21,7 @@ class SamplingParams:
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> None: ...
def __repr__(self) -> str: ...
# number of tokens to generate. truncated to model's max context length.
max_tokens: int
# number of sequences to generate for each prompt.
38 changes: 35 additions & 3 deletions scalellm/csrc/llm_handler.cpp
@@ -7,14 +7,18 @@

namespace llm::csrc {
namespace py = pybind11;
using namespace pybind11::literals;

void init_llm_handler(py::module_& m) {
py::class_<Message>(m, "Message")
.def(py::init<const std::string&, const std::string&>(),
py::arg("role"),
py::arg("content"))
.def_readwrite("role", &Message::role)
.def_readwrite("content", &Message::content);
.def_readwrite("content", &Message::content)
.def("__repr__", [](const Message& self) {
return "Message({}: {!r})"_s.format(self.role, self.content);
});

py::enum_<Priority>(m, "Priority")
.value("DEFAULT", Priority::NORMAL)
Expand Down Expand Up @@ -70,7 +74,10 @@ void init_llm_handler(py::module_& m) {
py::call_guard<py::gil_scoped_release>())
.def("reset",
&LLMHandler::reset,
py::call_guard<py::gil_scoped_release>());
py::call_guard<py::gil_scoped_release>())
.def("__repr__", [](const LLMHandler& self) {
return "LLMHandler({})"_s.format(self.options());
});

// LLMHandler::Options
py::class_<LLMHandler::Options>(llm_handler, "Options")
@@ -101,7 +108,32 @@ void init_llm_handler(py::module_& m) {
.def_readwrite("num_speculative_tokens",
&LLMHandler::Options::num_speculative_tokens_)
.def_readwrite("num_handling_threads",
&LLMHandler::Options::num_handling_threads_);
&LLMHandler::Options::num_handling_threads_)
.def("__repr__", [](const LLMHandler::Options& self) {
return "Options(model_path={}, devices={}, draft_model_path={}, "
"draft_devices={}, block_size={}, max_cache_size={}, "
"max_memory_utilization={}, enable_prefix_cache={}, "
"enable_cuda_graph={}, cuda_graph_max_seq_len={}, "
"cuda_graph_batch_sizes={}, draft_cuda_graph_batch_sizes={}, "
"max_tokens_per_batch={}, max_seqs_per_batch={}, "
"num_speculative_tokens={}, num_handling_threads={})"_s.format(
self.model_path_,
self.devices_,
self.draft_model_path_,
self.draft_devices_,
self.block_size_,
self.max_cache_size_,
self.max_memory_utilization_,
self.enable_prefix_cache_,
self.enable_cuda_graph_,
self.cuda_graph_max_seq_len_,
self.cuda_graph_batch_sizes_,
self.draft_cuda_graph_batch_sizes_,
self.max_tokens_per_batch_,
self.max_seqs_per_batch_,
self.num_speculative_tokens_,
self.num_handling_threads_);
});
}

} // namespace llm::csrc
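
The Options repr enumerates every field bound above, which makes engine configuration easy to inspect. A short sketch, assuming the handler is importable as laid out in the .pyi stubs; the path and device values are illustrative:

    from scalellm._C.llm_handler import LLMHandler  # import path assumed from the stub layout

    opts = LLMHandler.Options()
    opts.model_path = "/models/llama-2-7b"  # hypothetical local path
    opts.devices = "cuda:0"
    print(repr(opts))
    # e.g. Options(model_path=/models/llama-2-7b, devices=cuda:0, draft_model_path=None, ...)
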
41 changes: 35 additions & 6 deletions scalellm/csrc/output.cpp
@@ -7,13 +7,20 @@

namespace llm::csrc {
namespace py = pybind11;
using namespace pybind11::literals;

void init_output(py::module_& m) {
py::class_<Usage>(m, "Usage")
.def(py::init())
.def_readwrite("num_prompt_tokens", &Usage::num_prompt_tokens)
.def_readwrite("num_generated_tokens", &Usage::num_generated_tokens)
.def_readwrite("num_total_tokens", &Usage::num_total_tokens);
.def_readwrite("num_total_tokens", &Usage::num_total_tokens)
.def("__repr__", [](const Usage& self) {
return "Usage(num_prompt_tokens={}, num_generated_tokens={}, num_total_tokens={})"_s
.format(self.num_prompt_tokens,
self.num_generated_tokens,
self.num_total_tokens);
});

py::enum_<StatusCode>(m, "StatusCode")
.value("OK", StatusCode::OK)
@@ -33,38 +40,60 @@ void init_output(py::module_& m) {
py::arg("message"))
.def_property_readonly("code", &Status::code)
.def_property_readonly("message", &Status::message)
.def_property_readonly("ok", &Status::ok);
.def_property_readonly("ok", &Status::ok)
.def("__repr__", [](const Status& self) {
if (self.message().empty()) {
return "Status(code={})"_s.format(self.code());
}
return "Status(code={}, message={!r})"_s.format(self.code(),
self.message());
});

py::class_<LogProbData>(m, "LogProbData")
.def(py::init())
.def_readwrite("token", &LogProbData::token)
.def_readwrite("token_id", &LogProbData::token_id)
.def_readwrite("logprob", &LogProbData::logprob)
.def_readwrite("finished_token", &LogProbData::finished_token);
.def_readwrite("finished_token", &LogProbData::finished_token)
.def("__repr__", [](const LogProbData& self) {
return "LogProb(token={!r}, logprob={})"_s.format(self.token,
self.logprob);
});

py::class_<LogProb>(m, "LogProb")
.def(py::init())
.def_readwrite("token", &LogProbData::token)
.def_readwrite("token_id", &LogProbData::token_id)
.def_readwrite("logprob", &LogProbData::logprob)
.def_readwrite("finished_token", &LogProbData::finished_token)
.def_readwrite("top_logprobs", &LogProb::top_logprobs);
.def_readwrite("top_logprobs", &LogProb::top_logprobs)
.def("__repr__", [](const LogProb& self) {
return "LogProb(token={!r}, logprob={}, top_logprobs={})"_s.format(
self.token, self.logprob, self.top_logprobs);
});

py::class_<SequenceOutput>(m, "SequenceOutput")
.def(py::init())
.def_readwrite("index", &SequenceOutput::index)
.def_readwrite("text", &SequenceOutput::text)
.def_readwrite("token_ids", &SequenceOutput::token_ids)
.def_readwrite("finish_reason", &SequenceOutput::finish_reason)
.def_readwrite("logprobs", &SequenceOutput::logprobs);
.def_readwrite("logprobs", &SequenceOutput::logprobs)
.def("__repr__", [](const SequenceOutput& self) {
return "SequenceOutput({}: {!r})"_s.format(self.index, self.text);
});

py::class_<RequestOutput>(m, "RequestOutput")
.def(py::init())
.def_readwrite("prompt", &RequestOutput::prompt)
.def_readwrite("status", &RequestOutput::status)
.def_readwrite("outputs", &RequestOutput::outputs)
.def_readwrite("usage", &RequestOutput::usage)
.def_readwrite("finished", &RequestOutput::finished);
.def_readwrite("finished", &RequestOutput::finished)
.def("__repr__", [](const RequestOutput& self) {
return "RequestOutput({}, {}, {})"_s.format(
self.outputs, self.status, self.usage);
});
}

} // namespace llm::csrc
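
Usage, Status, LogProb, SequenceOutput, and RequestOutput now print their key fields instead of opaque object addresses. A small sketch using Usage, whose fields are writable per the binding above (top-level re-export assumed):

    from scalellm import Usage  # top-level re-export assumed

    usage = Usage()
    usage.num_prompt_tokens = 5
    usage.num_generated_tokens = 7
    usage.num_total_tokens = 12
    print(repr(usage))
    # Usage(num_prompt_tokens=5, num_generated_tokens=7, num_total_tokens=12)
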
26 changes: 25 additions & 1 deletion scalellm/csrc/sampling_params.cpp
@@ -5,6 +5,7 @@

namespace llm::csrc {
namespace py = pybind11;
using namespace pybind11::literals;

void init_sampling_params(py::module_& m) {
// class SamplingParameter
@@ -57,7 +58,30 @@ void init_sampling_params(py::module_& m) {
&SamplingParams::skip_special_tokens)
.def_readwrite("ignore_eos", &SamplingParams::ignore_eos)
.def_readwrite("stop", &SamplingParams::stop)
.def_readwrite("stop_token_ids", &SamplingParams::stop_token_ids);
.def_readwrite("stop_token_ids", &SamplingParams::stop_token_ids)
.def("__repr__", [](const SamplingParams& self) {
return "SamplingParams(max_tokens={}, n={}, best_of={}, echo={}, "
"frequency_penalty={}, presence_penalty={}, "
"repetition_penalty={}, temperature={}, top_p={}, top_k={}, "
"logprobs={}, top_logprobs={}, skip_special_tokens={}, "
"ignore_eos={}, stop={}, stop_token_ids={})"_s.format(
self.max_tokens,
self.n,
self.best_of,
self.echo,
self.frequency_penalty,
self.presence_penalty,
self.repetition_penalty,
self.temperature,
self.top_p,
self.top_k,
self.logprobs,
self.top_logprobs,
self.skip_special_tokens,
self.ignore_eos,
self.stop,
self.stop_token_ids);
});
}

} // namespace llm::csrc
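
SamplingParams now reports every knob in one line, which is handy when logging generation requests. A sketch, assuming the constructor accepts these keywords as in the stub; unspecified fields print whatever defaults the library uses:

    from scalellm import SamplingParams

    params = SamplingParams(max_tokens=32, temperature=0.7)
    print(repr(params))
    # e.g. SamplingParams(max_tokens=32, ..., temperature=0.7, ..., stop=None, stop_token_ids=None)
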
7 changes: 7 additions & 0 deletions scalellm/llm.py
@@ -33,6 +33,8 @@ def __init__(
num_handling_threads: int = 4,
) -> None:
# download hf model if it does not exist
self._model = model
self._draft_model = draft_model
model_path = model
if not os.path.exists(model_path):
model_path = download_hf_model(
@@ -139,3 +141,8 @@ def __enter__(self):
def __exit__(self, *args):
self.__del__()
return False

def __repr__(self) -> str:
if self._draft_model:
return f"LLM(model={self._model}, draft_model={self._draft_model})"
return f"LLM(model={self._model})"
7 changes: 7 additions & 0 deletions scalellm/llm_engine.py
@@ -125,6 +125,8 @@ def __init__(
num_speculative_tokens: int = 0,
num_handling_threads: int = 4,
) -> None:
self._model = model
self._draft_model = draft_model
# download hf model if it does not exist
model_path = model
if not os.path.exists(model_path):
@@ -274,3 +276,8 @@ def __exit__(self, *args):
self.stop()
self.__del__()
return False

def __repr__(self) -> str:
if self._draft_model:
return f"AsyncLLMEngine(model={self._model}, draft_model={self._draft_model})"
return f"AsyncLLMEngine(model={self._model})"
2 changes: 2 additions & 0 deletions src/handlers/llm_handler.h
@@ -157,6 +157,8 @@ class LLMHandler {
// release underlying resources
void reset();

const Options& options() const { return options_; }

private:
using Task = folly::Function<void(size_t tid)>;
std::unique_ptr<Request> create_request(size_t tid,
