diff --git a/src/lmql/models/lmtp/backends/transformers_model.py b/src/lmql/models/lmtp/backends/transformers_model.py
index 37be1b1c..34569b54 100644
--- a/src/lmql/models/lmtp/backends/transformers_model.py
+++ b/src/lmql/models/lmtp/backends/transformers_model.py
@@ -39,6 +39,9 @@ def __init__(self, model_identifier, **kwargs):
         self.max_batch_size = kwargs.get("batch_size", 32)
 
         self.silent = kwargs.pop("silent", False)
+        self.torch_compile = kwargs.pop("torch_compile", False)
+        if self.torch_compile and self.loader != "transformers":
+            raise ValueError("Torch compile is only supported for transformers models")
 
         if not self.silent:
             print("[Loading", self.model_identifier, "with", self.model_constructor() + "]", flush=True)
@@ -62,7 +65,9 @@ def __init__(self, model_identifier, **kwargs):
         else:
             from transformers import AutoModelForCausalLM
             self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, **self.model_args)
-
+        if self.torch_compile:
+            self.model = torch.compile(self.model)
+
         if self.loader == 'awq':
             self.device = self.model.model.device
         else:
diff --git a/src/lmql/models/lmtp/lmtp_serve.py b/src/lmql/models/lmtp/lmtp_serve.py
index 6b7a4b4f..785c86b9 100644
--- a/src/lmql/models/lmtp/lmtp_serve.py
+++ b/src/lmql/models/lmtp/lmtp_serve.py
@@ -87,7 +87,7 @@ def argparser(args):
     next_argument_name = None
     kwargs = {}
 
-    flag_args = ["cuda", "static", "single_thread", "docker_hide_port"]
+    flag_args = ["cuda", "static", "single_thread", "docker_hide_port", "torch_compile"]
 
     help_text = """
 usage: serve-model [-h] [--port PORT] [--host HOST] [--cuda] [--dtype DTYPE] [--[*] VALUE] model
@@ -110,6 +110,8 @@ def argparser(args):
    --single_thread      Run the model on the main thread. This can lead to increased latency when
                         processing multiple requests, but is necessary for some models that cannot
                         be run in the background (e.g. falcon).
 
+   --torch_compile      If set, the model will be compiled before serving. This can lead to an inference speedup (~30%), cf. https://huggingface.co/docs/transformers/main/perf_torch_compile
+
    --dtype DTYPE        What format to load the model weights in. Options: 'float16' (not available
                         on all models), '8bit' (requires bitsandbytes), '4bit' (requires bitsandbytes).
@@ -127,7 +129,6 @@ def argparser(args):
    --[*] VALUE          Any other argument will be passed as a keyword argument to the relevant backend
                         implementation, e.g. the AutoModelForCausalLM.from_pretrained function.
    """
-
     for arg in args:
         if arg == "-h" or arg == "--help":
             print(help_text)
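
A minimal sketch of what the new option does, assuming the plain "transformers" loader path
(the model identifier below is a placeholder, not taken from this diff):

    import torch
    from transformers import AutoModelForCausalLM

    # load the backbone the same way the "transformers" loader does
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # torch.compile wraps the module; the first forward pass triggers compilation,
    # later decoding steps reuse the compiled graph
    model = torch.compile(model)

With this change, the flag would be passed through the serving CLI roughly as follows
(hypothetical invocation, model name is illustrative):

    lmql serve-model gpt2 --cuda --torch_compile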