Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lmi] remove redundant auto logic from python handler #2152

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 3 additions & 23 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,6 @@
"BloomModel": "text-generation",
}

LMI_DIST_ADV_MODEL = {
"RWForCausalLM",
"GPTNeoXForCausalLM",
"T5ForConditionalGeneration",
"LlamaForCausalLM",
"FalconForCausalLM",
"MPTForCausalLM",
"GPTBigCodeForCausalLM",
}

# https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
FLASH_2_SUPPORTED_MODELS = {
"LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
Expand All @@ -85,17 +75,8 @@ def enable_flash():
return False


def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
model_config):
if rolling_batch_type == "auto":
architecture = model_config.architectures[0]
if architecture in LMI_DIST_ADV_MODEL and is_mpi:
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
else:
from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch
return SchedulerRollingBatch
elif rolling_batch_type == "scheduler":
def get_rolling_batch_class_from_str(rolling_batch_type: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have a deprecation warning or a fallback for "auto". There are still some customers setting that to "auto" today.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"auto" is still valid, but we just handle it here instead: https://github.com/deepjavalibrary/djl-serving/blob/master/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java#L81.

I don't think the logic here is ever invoked for "auto" in the current state.

if rolling_batch_type == "scheduler":
from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch
return SchedulerRollingBatch
elif rolling_batch_type == "lmi-dist":
Expand Down Expand Up @@ -149,8 +130,7 @@ def initialize(self, properties: dict):

if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
_rolling_batch_cls = get_rolling_batch_class_from_str(
self.hf_configs.rolling_batch.value, self.hf_configs.mpi_mode,
self.model_config)
self.hf_configs.rolling_batch.value)
self.hf_configs.kwargs["model_config"] = self.model_config
self.rolling_batch = _rolling_batch_cls(
self.hf_configs.model_id_or_path, properties,
Expand Down
6 changes: 2 additions & 4 deletions engines/python/setup/djl_python/tests/test_test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
from djl_python import huggingface


def override_rolling_batch(rolling_batch_type: str, is_mpi: bool,
model_config):
def override_rolling_batch(rolling_batch_type: str):
    """Test override: always return the FakeRollingBatch stub class.

    The rolling_batch_type argument exists only to mirror the signature of
    the real get_rolling_batch_class_from_str; its value is ignored.
    """
    from djl_python.tests.rolling_batch import fake_rolling_batch
    return fake_rolling_batch.FakeRollingBatch


def override_rolling_batch_with_exception(rolling_batch_type: str,
is_mpi: bool, model_config):
def override_rolling_batch_with_exception(rolling_batch_type: str):
    """Test override: always return the exception-raising fake batch class.

    The rolling_batch_type argument exists only to mirror the signature of
    the real get_rolling_batch_class_from_str; its value is ignored.
    """
    from djl_python.tests.rolling_batch import fake_rolling_batch
    return fake_rolling_batch.FakeRollingBatchWithException

Expand Down
Loading