add dtype-based loading #461

Merged: 1 commit, Nov 13, 2024
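The changes below all key off one object, engine_args._loading_strategy: it carries the dtype to materialize weights in (loading_dtype), an optional post-load quantization target (quantization_dtype), and the device the model is placed on (device_placement). As a reading aid, here is a minimal sketch of such a strategy object; the field names are taken from the diff, but the class shape is an assumption, not the actual infinity_emb definition:

# Hypothetical sketch: field names inferred from the diff, not the real infinity_emb class.
from dataclasses import dataclass
from typing import Optional, Union

import torch


@dataclass
class LoadingStrategy:
    loading_dtype: Optional[torch.dtype]        # e.g. torch.float16; forwarded as torch_dtype to the loader
    quantization_dtype: Optional[str]           # e.g. "int8"; triggers quant_interface after loading
    device_placement: Union[str, torch.device]  # e.g. "cuda:0"; where the model is placed


# Example: fp16 weights on the first GPU, no extra quantization.
fp16_on_gpu = LoadingStrategy(
    loading_dtype=torch.float16,
    quantization_dtype=None,
    device_placement="cuda:0",
)

With that shape in mind, each touched model class follows the same pattern: feed loading_dtype into the loader via model_kwargs and place the model with device_placement.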
25 changes: 21 additions & 4 deletions libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
@@ -1,14 +1,18 @@
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2023-now michaelfeil

-from infinity_emb._optional_imports import CHECK_TRANSFORMERS
+from infinity_emb._optional_imports import CHECK_TRANSFORMERS, CHECK_TORCH
 from infinity_emb.args import EngineArgs
 from infinity_emb.log_handler import logger
 from infinity_emb.transformer.abstract import BaseClassifer
 from infinity_emb.transformer.acceleration import to_bettertransformer
+from infinity_emb.transformer.quantization.interface import quant_interface
+from infinity_emb.primitives import Device

 if CHECK_TRANSFORMERS.is_available:
     from transformers import AutoTokenizer, pipeline  # type: ignore
+if CHECK_TORCH.is_available:
+    import torch


 class SentenceClassifier(BaseClassifer):
@@ -21,24 +25,37 @@ def __init__(
         model_kwargs = {}
         if engine_args.bettertransformer:
             model_kwargs["attn_implementation"] = "eager"
+        ls = engine_args._loading_strategy
+        assert ls is not None
+
+        if ls.loading_dtype is not None:  # type: ignore
+            model_kwargs["torch_dtype"] = ls.loading_dtype
+
         self._pipe = pipeline(
             task="text-classification",
             model=engine_args.model_name_or_path,
             trust_remote_code=engine_args.trust_remote_code,
-            device=engine_args.device.resolve(),
+            device=ls.device_placement,
             top_k=None,
             revision=engine_args.revision,
             model_kwargs=model_kwargs,
         )
-        if self._pipe.device.type != "cpu":  # and engine_args.dtype == "float16":
-            self._pipe.model = self._pipe.model.half()

         self._pipe.model = to_bettertransformer(
             self._pipe.model,
             engine_args,
             logger,
         )

+        if ls.quantization_dtype is not None:
+            self._pipe.model = quant_interface(  # TODO: add ls.quantization_dtype and ls.placement
+                self._pipe.model, engine_args.dtype, device=Device[self._pipe.model.device.type]
+            )
+
+        if engine_args.compile:
+            logger.info("using torch.compile(dynamic=True)")
+            self._pipe.model = torch.compile(self._pipe.model, dynamic=True)
+
         self._infinity_tokenizer = AutoTokenizer.from_pretrained(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
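The net effect in the classifier: instead of unconditionally calling .half() on any non-CPU device, the requested dtype is forwarded to the transformers pipeline at load time. A standalone sketch of the equivalent call, with a placeholder model id and an explicit dtype standing in for ls.loading_dtype:

# Sketch only: placeholder model id and device; ls.loading_dtype replaced by an explicit dtype.
import torch
from transformers import pipeline

pipe = pipeline(
    task="text-classification",
    model="SamLowe/roberta-base-go_emotions",  # placeholder model id
    device="cpu",
    top_k=None,
    model_kwargs={"torch_dtype": torch.float32},  # dtype chosen at load, not cast afterwards
)
print(pipe("infinity makes embeddings fast"))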
33 changes: 23 additions & 10 deletions libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
@@ -9,8 +9,11 @@
 from infinity_emb._optional_imports import CHECK_SENTENCE_TRANSFORMERS, CHECK_TORCH
 from infinity_emb.args import EngineArgs
 from infinity_emb.log_handler import logger
-from infinity_emb.primitives import Dtype
+from infinity_emb.primitives import Device
 from infinity_emb.transformer.abstract import BaseCrossEncoder
+from infinity_emb.transformer.quantization.interface import (
+    quant_interface,
+)

 if CHECK_TORCH.is_available and CHECK_SENTENCE_TRANSFORMERS.is_available:
     import torch
@@ -42,14 +45,20 @@ def __init__(self, *, engine_args: EngineArgs):
         if engine_args.bettertransformer:
             model_kwargs["attn_implementation"] = "eager"

+        ls = engine_args._loading_strategy
+        assert ls is not None
+
+        if ls.loading_dtype is not None:  # type: ignore
+            model_kwargs["torch_dtype"] = ls.loading_dtype
+
         super().__init__(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
-            device=engine_args.device.resolve(),  # type: ignore
             trust_remote_code=engine_args.trust_remote_code,
+            device=ls.device_placement,
             automodel_args=model_kwargs,
         )
-        self.model.to(self._target_device)  # type: ignore
+        self.model.to(ls.device_placement)

         # make a copy of the tokenizer,
         # to be able to count the tokens in another thread
@@ -64,12 +73,16 @@ def __init__(self, *, engine_args: EngineArgs):
             logger,
         )

-        if self._target_device.type == "cuda" and engine_args.dtype in [
-            Dtype.auto,
-            Dtype.float16,
-        ]:
-            logger.info("Switching to half() precision (cuda: fp16). ")
-            self.model.to(dtype=torch.float16)
+        self.model.to(ls.loading_dtype)
+
+        if ls.quantization_dtype is not None:
+            self.model = quant_interface(  # TODO: add ls.quantization_dtype and ls.placement
+                self.model, engine_args.dtype, device=Device[self.model.device.type]
+            )
+
+        if engine_args.compile:
+            logger.info("using torch.compile(dynamic=True)")
+            self.model = torch.compile(self.model, dynamic=True)

     def encode_pre(self, input_tuples: list[tuple[str, str]]):
         # return input_tuples
@@ -91,7 +104,7 @@ def encode_core(self, features: dict[str, "Tensor"]):
         return out_features.detach().cpu()

     def encode_post(self, out_features) -> list[float]:
-        return out_features.flatten()
+        return out_features.flatten().to(torch.float32).numpy()

     def tokenize_lengths(self, sentences: list[str]) -> list[int]:
         tks = self._infinity_tokenizer.batch_encode_plus(
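The encode_post change is more than cosmetic: now that weights, and thus output logits, can be fp16 or bf16, the explicit upcast is needed because numpy has no bfloat16 type, so calling .numpy() directly on a bf16 tensor raises a TypeError. A quick illustration:

import torch

scores = torch.tensor([0.12, 0.87], dtype=torch.bfloat16)
# scores.numpy() would raise TypeError (unsupported ScalarType BFloat16)
out = scores.flatten().to(torch.float32).numpy()  # upcast first, then hand off to numpy
print(out.dtype)  # float32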
@@ -57,14 +57,17 @@ def __init__(self, *, engine_args=EngineArgs):
             model_kwargs["attn_implementation"] = "eager"

         ls = engine_args._loading_strategy
         assert ls is not None

+        if ls.loading_dtype is not None:
+            model_kwargs["torch_dtype"] = ls.loading_dtype
+
         super().__init__(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
             trust_remote_code=engine_args.trust_remote_code,
             device=ls.device_placement,
-            # TODO: set torch_dtype=ls.loading_dtype to save memory on loading.
+            model_kwargs=model_kwargs,
         )
         self.to(ls.device_placement)
         # make a copy of the tokenizer,
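The last hunk replaces a TODO with the behavior it asked for: passing the dtype through model_kwargs means weights are materialized directly in the target dtype, rather than loaded as float32 and cast down afterwards, which transiently holds the full fp32 copy in memory. A minimal sketch of the difference, using a placeholder model id:

# Sketch: direct-dtype loading vs. load-then-cast; model id is a placeholder.
import torch
from transformers import AutoModel

# Weights materialized as fp16 (peak memory roughly the fp16 size):
model_a = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5", torch_dtype=torch.float16)

# Weights loaded as fp32 first, then cast (peak memory roughly the fp32 size):
model_b = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5").half()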