From 63a4b45463e62df4653843747846652eb2e73ea0 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Wed, 12 Jun 2024 23:34:47 -0700 Subject: [PATCH] feat: added '__repr__' function for scalellm package (#242) --- scalellm/_C/llm_handler.pyi | 3 +++ scalellm/_C/output.pyi | 6 +++++ scalellm/_C/sampling_params.pyi | 1 + scalellm/csrc/llm_handler.cpp | 38 +++++++++++++++++++++++++--- scalellm/csrc/output.cpp | 41 ++++++++++++++++++++++++++----- scalellm/csrc/sampling_params.cpp | 26 +++++++++++++++++++- scalellm/llm.py | 7 ++++++ scalellm/llm_engine.py | 7 ++++++ src/handlers/llm_handler.h | 2 ++ 9 files changed, 121 insertions(+), 10 deletions(-) diff --git a/scalellm/_C/llm_handler.pyi b/scalellm/_C/llm_handler.pyi index d49cc067..301167c4 100644 --- a/scalellm/_C/llm_handler.pyi +++ b/scalellm/_C/llm_handler.pyi @@ -7,6 +7,7 @@ from scalellm._C.sampling_params import SamplingParams # Defined in csrc/llm_handler.cpp class Message: def __init__(self, role: str, content: str) -> None: ... + def __repr__(self) -> str: ... role: str content: str @@ -27,6 +28,7 @@ class BatchFuture: class LLMHandler: class Options: def __init__(self) -> None: ... + def __repr__(self) -> str: ... model_path: str devices: Optional[str] draft_model_path: Optional[str] @@ -45,6 +47,7 @@ class LLMHandler: num_handling_threads: int def __init__(self, options: Options) -> None: ... + def __repr__(self) -> str: ... def schedule_async( self, prompt: str, diff --git a/scalellm/_C/output.pyi b/scalellm/_C/output.pyi index d61614d5..742de273 100644 --- a/scalellm/_C/output.pyi +++ b/scalellm/_C/output.pyi @@ -4,12 +4,14 @@ from typing import List, Optional # Defined in csrc/output.cpp class Usage: def __init__(self) -> None: ... + def __repr__(self) -> str: ... num_prompt_tokens: int num_generated_tokens: int num_total_tokens: int class LogProbData: def __init__(self) -> None: ... + def __repr__(self) -> str: ... token: str token_id: int logprob: float @@ -17,6 +19,7 @@ class LogProbData: class LogProb: def __init__(self) -> None: ... + def __repr__(self) -> str: ... token: str token_id: int logprob: float @@ -25,6 +28,7 @@ class LogProb: class SequenceOutput: def __init__(self) -> None: ... + def __repr__(self) -> str: ... index: int text: str token_ids: List[int] @@ -33,6 +37,7 @@ class SequenceOutput: class RequestOutput: def __init__(self) -> None: ... + def __repr__(self) -> str: ... prompt: Optional[str] status: Optional[Status] outputs: List[SequenceOutput] @@ -52,6 +57,7 @@ class StatusCode(Enum): class Status: def __init__(self, code: StatusCode, message: str) -> None: ... + def __repr__(self) -> str: ... @property def code(self) -> StatusCode: ... @property diff --git a/scalellm/_C/sampling_params.pyi b/scalellm/_C/sampling_params.pyi index b1921043..7c6be1b4 100644 --- a/scalellm/_C/sampling_params.pyi +++ b/scalellm/_C/sampling_params.pyi @@ -21,6 +21,7 @@ class SamplingParams: stop: Optional[List[str]] = None, stop_token_ids: Optional[List[int]] = None, ) -> None: ... + def __repr__(self) -> str: ... # number of tokens to generate. truncted to model's max context length. max_tokens: int # number of sequences to generate for each prompt. diff --git a/scalellm/csrc/llm_handler.cpp b/scalellm/csrc/llm_handler.cpp index 0f29fa3f..568b57e2 100644 --- a/scalellm/csrc/llm_handler.cpp +++ b/scalellm/csrc/llm_handler.cpp @@ -7,6 +7,7 @@ namespace llm::csrc { namespace py = pybind11; +using namespace pybind11::literals; void init_llm_handler(py::module_& m) { py::class_(m, "Message") @@ -14,7 +15,10 @@ void init_llm_handler(py::module_& m) { py::arg("role"), py::arg("content")) .def_readwrite("role", &Message::role) - .def_readwrite("content", &Message::content); + .def_readwrite("content", &Message::content) + .def("__repr__", [](const Message& self) { + return "Message({}: {!r})"_s.format(self.role, self.content); + }); py::enum_(m, "Priority") .value("DEFAULT", Priority::NORMAL) @@ -70,7 +74,10 @@ void init_llm_handler(py::module_& m) { py::call_guard()) .def("reset", &LLMHandler::reset, - py::call_guard()); + py::call_guard()) + .def("__repr__", [](const LLMHandler& self) { + return "LLMHandler({})"_s.format(self.options()); + }); // LLMHandler::Options py::class_(llm_handler, "Options") @@ -101,7 +108,32 @@ void init_llm_handler(py::module_& m) { .def_readwrite("num_speculative_tokens", &LLMHandler::Options::num_speculative_tokens_) .def_readwrite("num_handling_threads", - &LLMHandler::Options::num_handling_threads_); + &LLMHandler::Options::num_handling_threads_) + .def("__repr__", [](const LLMHandler::Options& self) { + return "Options(model_path={}, devices={}, draft_model_path={}, " + "draft_devices={}, block_size={}, max_cache_size={}, " + "max_memory_utilization={}, enable_prefix_cache={}, " + "enable_cuda_graph={}, cuda_graph_max_seq_len={}, " + "cuda_graph_batch_sizes={}, draft_cuda_graph_batch_sizes={}, " + "max_tokens_per_batch={}, max_seqs_per_batch={}, " + "num_speculative_tokens={}, num_handling_threads={})"_s.format( + self.model_path_, + self.devices_, + self.draft_model_path_, + self.draft_devices_, + self.block_size_, + self.max_cache_size_, + self.max_memory_utilization_, + self.enable_prefix_cache_, + self.enable_cuda_graph_, + self.cuda_graph_max_seq_len_, + self.cuda_graph_batch_sizes_, + self.draft_cuda_graph_batch_sizes_, + self.max_tokens_per_batch_, + self.max_seqs_per_batch_, + self.num_speculative_tokens_, + self.num_handling_threads_); + }); } } // namespace llm::csrc \ No newline at end of file diff --git a/scalellm/csrc/output.cpp b/scalellm/csrc/output.cpp index 8e9b3fdb..50af39e3 100644 --- a/scalellm/csrc/output.cpp +++ b/scalellm/csrc/output.cpp @@ -7,13 +7,20 @@ namespace llm::csrc { namespace py = pybind11; +using namespace pybind11::literals; void init_output(py::module_& m) { py::class_(m, "Usage") .def(py::init()) .def_readwrite("num_prompt_tokens", &Usage::num_prompt_tokens) .def_readwrite("num_generated_tokens", &Usage::num_generated_tokens) - .def_readwrite("num_total_tokens", &Usage::num_total_tokens); + .def_readwrite("num_total_tokens", &Usage::num_total_tokens) + .def("__repr__", [](const Usage& self) { + return "Usage(num_prompt_tokens={}, num_generated_tokens={}, num_total_tokens={})"_s + .format(self.num_prompt_tokens, + self.num_generated_tokens, + self.num_total_tokens); + }); py::enum_(m, "StatusCode") .value("OK", StatusCode::OK) @@ -33,14 +40,25 @@ void init_output(py::module_& m) { py::arg("message")) .def_property_readonly("code", &Status::code) .def_property_readonly("message", &Status::message) - .def_property_readonly("ok", &Status::ok); + .def_property_readonly("ok", &Status::ok) + .def("__repr__", [](const Status& self) { + if (self.message().empty()) { + return "Status(code={})"_s.format(self.code()); + } + return "Status(code={}, message={!r})"_s.format(self.code(), + self.message()); + }); py::class_(m, "LogProbData") .def(py::init()) .def_readwrite("token", &LogProbData::token) .def_readwrite("token_id", &LogProbData::token_id) .def_readwrite("logprob", &LogProbData::logprob) - .def_readwrite("finished_token", &LogProbData::finished_token); + .def_readwrite("finished_token", &LogProbData::finished_token) + .def("__repr__", [](const LogProbData& self) { + return "LogProb(token={!r}, logprob={})"_s.format(self.token, + self.logprob); + }); py::class_(m, "LogProb") .def(py::init()) @@ -48,7 +66,11 @@ void init_output(py::module_& m) { .def_readwrite("token_id", &LogProbData::token_id) .def_readwrite("logprob", &LogProbData::logprob) .def_readwrite("finished_token", &LogProbData::finished_token) - .def_readwrite("top_logprobs", &LogProb::top_logprobs); + .def_readwrite("top_logprobs", &LogProb::top_logprobs) + .def("__repr__", [](const LogProb& self) { + return "LogProb(token={!r}, logprob={}, top_logprobs={})"_s.format( + self.token, self.logprob, self.top_logprobs); + }); py::class_(m, "SequenceOutput") .def(py::init()) @@ -56,7 +78,10 @@ void init_output(py::module_& m) { .def_readwrite("text", &SequenceOutput::text) .def_readwrite("token_ids", &SequenceOutput::token_ids) .def_readwrite("finish_reason", &SequenceOutput::finish_reason) - .def_readwrite("logprobs", &SequenceOutput::logprobs); + .def_readwrite("logprobs", &SequenceOutput::logprobs) + .def("__repr__", [](const SequenceOutput& self) { + return "SequenceOutput({}: {!r})"_s.format(self.index, self.text); + }); py::class_(m, "RequestOutput") .def(py::init()) @@ -64,7 +89,11 @@ void init_output(py::module_& m) { .def_readwrite("status", &RequestOutput::status) .def_readwrite("outputs", &RequestOutput::outputs) .def_readwrite("usage", &RequestOutput::usage) - .def_readwrite("finished", &RequestOutput::finished); + .def_readwrite("finished", &RequestOutput::finished) + .def("__repr__", [](const RequestOutput& self) { + return "RequestOutput({}, {}, {})"_s.format( + self.outputs, self.status, self.usage); + }); } } // namespace llm::csrc \ No newline at end of file diff --git a/scalellm/csrc/sampling_params.cpp b/scalellm/csrc/sampling_params.cpp index fc7bf2b7..0c63a91b 100644 --- a/scalellm/csrc/sampling_params.cpp +++ b/scalellm/csrc/sampling_params.cpp @@ -5,6 +5,7 @@ namespace llm::csrc { namespace py = pybind11; +using namespace pybind11::literals; void init_sampling_params(py::module_& m) { // class SamplingParameter @@ -57,7 +58,30 @@ void init_sampling_params(py::module_& m) { &SamplingParams::skip_special_tokens) .def_readwrite("ignore_eos", &SamplingParams::ignore_eos) .def_readwrite("stop", &SamplingParams::stop) - .def_readwrite("stop_token_ids", &SamplingParams::stop_token_ids); + .def_readwrite("stop_token_ids", &SamplingParams::stop_token_ids) + .def("__repr__", [](const SamplingParams& self) { + return "SamplingParams(max_tokens={}, n={}, best_of={}, echo={}, " + "frequency_penalty={}, presence_penalty={}, " + "repetition_penalty={}, temperature={}, top_p={}, top_k={}, " + "logprobs={}, top_logprobs={}, skip_special_tokens={}, " + "ignore_eos={}, stop={}, stop_token_ids={})"_s.format( + self.max_tokens, + self.n, + self.best_of, + self.echo, + self.frequency_penalty, + self.presence_penalty, + self.repetition_penalty, + self.temperature, + self.top_p, + self.top_k, + self.logprobs, + self.top_logprobs, + self.skip_special_tokens, + self.ignore_eos, + self.stop, + self.stop_token_ids); + }); } } // namespace llm::csrc \ No newline at end of file diff --git a/scalellm/llm.py b/scalellm/llm.py index e6b422c2..f9207a4c 100644 --- a/scalellm/llm.py +++ b/scalellm/llm.py @@ -33,6 +33,8 @@ def __init__( num_handling_threads: int = 4, ) -> None: # download hf model if it does not exist + self._model = model + self._draft_model = draft_model model_path = model if not os.path.exists(model_path): model_path = download_hf_model( @@ -139,3 +141,8 @@ def __enter__(self): def __exit__(self, *args): self.__del__() return False + + def __repr__(self) -> str: + if self._draft_model: + return f"LLM(model={self._model}, draft_model={self._draft_model})" + return f"LLM(model={self._model})" diff --git a/scalellm/llm_engine.py b/scalellm/llm_engine.py index 06768aee..73a295b8 100644 --- a/scalellm/llm_engine.py +++ b/scalellm/llm_engine.py @@ -125,6 +125,8 @@ def __init__( num_speculative_tokens: int = 0, num_handling_threads: int = 4, ) -> None: + self._model = model + self._draft_model = draft_model # download hf model if it does not exist model_path = model if not os.path.exists(model_path): @@ -274,3 +276,8 @@ def __exit__(self, *args): self.stop() self.__del__() return False + + def __repr__(self) -> str: + if self._draft_model: + return f"AsyncLLMEngine(model={self._model}, draft_model={self._draft_model})" + return f"AsyncLLMEngine(model={self._model})" diff --git a/src/handlers/llm_handler.h b/src/handlers/llm_handler.h index b070b898..c0274489 100644 --- a/src/handlers/llm_handler.h +++ b/src/handlers/llm_handler.h @@ -157,6 +157,8 @@ class LLMHandler { // release underlying resources void reset(); + const Options& options() const { return options_; } + private: using Task = folly::Function; std::unique_ptr create_request(size_t tid,