diff --git a/engines/python/setup/djl_python/chat_completions/chat_properties.py b/engines/python/setup/djl_python/chat_completions/chat_properties.py
index 1bd5db261..6e76c01e1 100644
--- a/engines/python/setup/djl_python/chat_completions/chat_properties.py
+++ b/engines/python/setup/djl_python/chat_completions/chat_properties.py
@@ -81,7 +81,6 @@ def get_tokenizer_inputs(self, image_token="<image>"):
         prompt_text = '\n'.join(texts)
         if len(images) > 0:
-            # TODO: Find a reliable way to get the image token from tokenizer
             prompt_text = f"{image_token}\n{prompt_text}"
         return {
             "role": self.role,
diff --git a/engines/python/setup/djl_python/chat_completions/chat_utils.py b/engines/python/setup/djl_python/chat_completions/chat_utils.py
index 4509306d5..346e0ee0a 100644
--- a/engines/python/setup/djl_python/chat_completions/chat_utils.py
+++ b/engines/python/setup/djl_python/chat_completions/chat_utils.py
@@ -10,7 +10,7 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-from typing import Dict
+from typing import Dict, Optional
 
 from djl_python.chat_completions.chat_properties import ChatProperties
 
@@ -19,8 +19,10 @@ def is_chat_completions_request(inputs: Dict) -> bool:
     return "messages" in inputs
 
 
-def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
-                                   tokenizer):
+def parse_chat_completions_request(input_map: Dict,
+                                   is_rolling_batch: bool,
+                                   tokenizer,
+                                   image_token: Optional[str] = None):
     if not is_rolling_batch:
         raise ValueError(
             "chat completions support is not currently available for dynamic batching. "
" @@ -36,7 +38,8 @@ def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool, images = [] tokenizer_inputs = [] for message in messages: - tokenizer_inputs.append(message.get_tokenizer_inputs()) + tokenizer_inputs.append( + message.get_tokenizer_inputs(image_token=image_token)) images.extend(message.get_images()) inputs = tokenizer.apply_chat_template(tokenizer_inputs, tokenize=False) param[ diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py index 8e3ae70d9..e47d52a6f 100644 --- a/engines/python/setup/djl_python/huggingface.py +++ b/engines/python/setup/djl_python/huggingface.py @@ -163,7 +163,8 @@ def get_input_format_args(self): "adapter_registry": self.adapter_registry, "model_config": self.model_config, "peft_config": self.peft_config, - "rolling_batch": self.rolling_batch + "rolling_batch": self.rolling_batch, + "image_placeholder_token": self.get_image_token(), } @staticmethod @@ -460,6 +461,29 @@ def _read_model_config(self, model_config_path: str): exc_info=True) raise e + def get_image_token(self): + if self.hf_configs.image_placeholder_token: + return self.hf_configs.image_placeholder_token + + logging.warning( + "image_placeholder_token is not explicitly set. It is highly recommended to explicitly" + "set the image_placeholder_token as it differs between models, and is not easy to infer from the model or tokenizer" + ) + + # TODO: Improve. 
+        # We hardcode these for known model architectures as it is the most accurate and quickest way to set
+        # This is less than ideal, but until there is a good way to obtain this from the tokenizer/model, it's the best way to do so
+        model_type = self.model_config.model_type
+        if model_type == "phi3_v":
+            # phi3_v does support multiple images, but vllm/lmi-dist can only support 1 per request
+            return "<|image_1|>"
+        if model_type in {"llava", "llava_next", "paligemma"}:
+            return "<image>"
+
+        logging.warning(
+            "could not infer image token from the model artifacts. Using <image> as default."
+        )
+        return "<image>"
+
 
 _service = HuggingFaceService()
diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py
index b2ecaffb2..5c8201b7c 100644
--- a/engines/python/setup/djl_python/input_parser.py
+++ b/engines/python/setup/djl_python/input_parser.py
@@ -127,9 +127,13 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
                             input_map: Dict, **kwargs):
     invoke_type = input_item.get_property("X-Amzn-SageMaker-Forwarded-Api")
     tokenizer = kwargs.get("tokenizer")
+    image_token = kwargs.get("image_placeholder_token")
     if is_chat_completions_request(input_map):
         inputs, param = parse_chat_completions_request(
-            input_map, kwargs.get("is_rolling_batch"), tokenizer)
+            input_map,
+            kwargs.get("is_rolling_batch"),
+            tokenizer,
+            image_token=image_token)
     elif is_3p_request(invoke_type):
         inputs, param = parse_3p_request(input_map,
                                          kwargs.get("is_rolling_batch"),
diff --git a/engines/python/setup/djl_python/properties_manager/hf_properties.py b/engines/python/setup/djl_python/properties_manager/hf_properties.py
index 8b3ab3e2e..c10e2d51d 100644
--- a/engines/python/setup/djl_python/properties_manager/hf_properties.py
+++ b/engines/python/setup/djl_python/properties_manager/hf_properties.py
@@ -61,6 +61,7 @@ class HuggingFaceProperties(Properties):
     device: Optional[str] = None
     kwargs: Optional[dict] = {}
     data_type: Optional[str] = None
+    image_placeholder_token: Optional[str] = None
 
     @field_validator('load_in_4bit')
     def validate_load_in_4bit(cls, load_in_4bit):