Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[multimodal] support specifying image_token, inferring default if not … #2183

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def get_tokenizer_inputs(self, image_token="<image>"):

prompt_text = '\n'.join(texts)
if len(images) > 0:
# TODO: Find a reliable way to get the image token from tokenizer
prompt_text = f"{image_token}\n{prompt_text}"
return {
"role": self.role,
Expand Down
11 changes: 7 additions & 4 deletions engines/python/setup/djl_python/chat_completions/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Dict
from typing import Dict, Optional

from djl_python.chat_completions.chat_properties import ChatProperties

Expand All @@ -19,8 +19,10 @@ def is_chat_completions_request(inputs: Dict) -> bool:
return "messages" in inputs


def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
tokenizer):
def parse_chat_completions_request(input_map: Dict,
is_rolling_batch: bool,
tokenizer,
image_token: Optional[str] = None):
if not is_rolling_batch:
raise ValueError(
"chat completions support is not currently available for dynamic batching. "
Expand All @@ -36,7 +38,8 @@ def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
images = []
tokenizer_inputs = []
for message in messages:
tokenizer_inputs.append(message.get_tokenizer_inputs())
tokenizer_inputs.append(
message.get_tokenizer_inputs(image_token=image_token))
images.extend(message.get_images())
inputs = tokenizer.apply_chat_template(tokenizer_inputs, tokenize=False)
param[
Expand Down
26 changes: 25 additions & 1 deletion engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def get_input_format_args(self):
"adapter_registry": self.adapter_registry,
"model_config": self.model_config,
"peft_config": self.peft_config,
"rolling_batch": self.rolling_batch
"rolling_batch": self.rolling_batch,
"image_placeholder_token": self.get_image_token(),
}

@staticmethod
Expand Down Expand Up @@ -460,6 +461,29 @@ def _read_model_config(self, model_config_path: str):
exc_info=True)
raise e

def get_image_token(self):
    """Return the placeholder token that marks image positions in prompts.

    Resolution order:
      1. The explicitly configured ``image_placeholder_token`` (preferred,
         since the token differs between model families).
      2. A hardcoded token for known multimodal architectures, keyed on
         ``model_config.model_type``.
      3. The common default ``"<image>"`` as a last resort.

    Returns:
        str: the image placeholder token to insert into prompt text.
    """
    # An explicitly configured token always wins; it is the only reliable source.
    if self.hf_configs.image_placeholder_token:
        return self.hf_configs.image_placeholder_token

    logging.warning(
        "image_placeholder_token is not explicitly set. It is highly recommended to explicitly "
        "set the image_placeholder_token as it differs between models, and is not easy to infer from the model or tokenizer"
    )

    # TODO: Improve. We hardcode these for known model architectures as it is the most accurate and quickest way to set
    # This is less than ideal, but until there is a good way to obtain this from the tokenizer/model, it's the best way to do so
    model_type = self.model_config.model_type
    if model_type == "phi3_v":
        # phi3_v does support multiple images, but vllm/lmi-dist can only support 1 per request
        return "<|image_1|>"
    if model_type in {"llava", "llava_next", "paligemma"}:
        return "<image>"

    logging.warning(
        "could not infer image token from the model artifacts. Using <image> as default."
    )
    return "<image>"


_service = HuggingFaceService()

Expand Down
6 changes: 5 additions & 1 deletion engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,13 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
input_map: Dict, **kwargs):
invoke_type = input_item.get_property("X-Amzn-SageMaker-Forwarded-Api")
tokenizer = kwargs.get("tokenizer")
image_token = kwargs.get("image_placeholder_token")
if is_chat_completions_request(input_map):
inputs, param = parse_chat_completions_request(
input_map, kwargs.get("is_rolling_batch"), tokenizer)
input_map,
kwargs.get("is_rolling_batch"),
tokenizer,
image_token=image_token)
elif is_3p_request(invoke_type):
inputs, param = parse_3p_request(input_map,
kwargs.get("is_rolling_batch"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class HuggingFaceProperties(Properties):
device: Optional[str] = None
kwargs: Optional[dict] = {}
data_type: Optional[str] = None
image_placeholder_token: Optional[str] = None

@field_validator('load_in_4bit')
def validate_load_in_4bit(cls, load_in_4bit):
Expand Down
Loading