Pydantic v2 migration #423

Merged
merged 11 commits on Aug 22, 2024
docs/requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
asyncio
autodoc_pydantic<2.0.0
autodoc_pydantic>=2.0.0
deepspeed>=0.13.0
grpcio
grpcio-tools
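The docs requirement flips from a Pydantic-v1-only ceiling to a v2 floor, since only the 2.x line of the autodoc_pydantic Sphinx extension understands Pydantic v2 models. As a hedged sketch of where that dependency is consumed (the `conf.py` below is illustrative, not part of this PR; option names come from autodoc_pydantic's own documentation):

```python
# docs/conf.py -- illustrative only, not from this repository.
extensions = [
    "sphinx.ext.autodoc",
    "sphinxcontrib.autodoc_pydantic",  # the >=2.0.0 releases target Pydantic v2
]

# Example rendering knobs the extension exposes for pydantic models:
autodoc_pydantic_model_show_json = True
autodoc_pydantic_model_show_field_summary = True
```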
mii/api.py (10 changes: 5 additions & 5 deletions)
@@ -39,7 +39,7 @@ def _parse_kwargs_to_model_config(
# Fill model_config dict with relevant kwargs, store remaining kwargs in a new dict
remaining_kwargs = {}
for key, val in kwargs.items():
if key in ModelConfig.__dict__["__fields__"]:
if key in ModelConfig.model_fields.keys():
if key in model_config:
assert (
model_config.get(key) == val
@@ -77,7 +77,7 @@ def _parse_kwargs_to_mii_config(

# Fill mii_config dict with relevant kwargs, raise error on unknown kwargs
for key, val in remaining_kwargs.items():
if key in MIIConfig.__dict__["__fields__"]:
if key in MIIConfig.model_fields.keys():
if key in mii_config:
assert (
mii_config.get(key) == val
@@ -183,9 +183,9 @@ def serve(
mii.aml_related.utils.generate_aml_scripts(
acr_name=acr_name,
deployment_name=mii_config.deployment_name,
model_name=mii_config.model_config.model,
task_name=mii_config.model_config.task,
replica_num=mii_config.model_config.replica_num,
model_name=mii_config.model_conf.model,
task_name=mii_config.model_conf.task,
replica_num=mii_config.model_conf.replica_num,
instance_type=mii_config.instance_type,
version=mii_config.version,
)
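The kwarg-routing change above is the standard Pydantic v2 replacement for reaching into `__fields__` via the class dict: v2 exposes declared fields through the public `model_fields` mapping (the trailing `.keys()` is optional, since `in` works directly on the mapping). A minimal sketch of the same sorting pattern, using a made-up `Settings` model rather than MII's real config classes:

```python
from pydantic import BaseModel

class Settings(BaseModel):  # hypothetical model, for illustration only
    host: str = "localhost"
    port: int = 50050

# Pydantic v1 style:  "port" in Settings.__fields__
# Pydantic v2 style:  "port" in Settings.model_fields
known, remaining = {}, {}
for key, val in {"port": 8080, "verbose": True}.items():
    (known if key in Settings.model_fields else remaining)[key] = val

assert known == {"port": 8080} and remaining == {"verbose": True}
```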
mii/backend/client.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ class MIIClient:
"""
def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None:
self.mii_config = mii_config
self.task = mii_config.model_config.task
self.task = mii_config.model_conf.task
self.port = mii_config.port_number
self.asyncio_loop = asyncio.get_event_loop()
channel = create_channel(host, self.port)
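The `model_config` attribute is renamed to `model_conf` throughout because Pydantic v2 reserves the name `model_config` for a model's own configuration (`ConfigDict`), so it can no longer be a field. The alias added in `mii/config.py` below keeps the old keyword working for callers. A hedged sketch of that pattern with throwaway class names (`Inner`, `Wrapper`), not the real MII classes:

```python
from pydantic import BaseModel, ConfigDict, Field

class Inner(BaseModel):
    task: str = "text-generation"

class Wrapper(BaseModel):
    # `model_config` is Pydantic v2's reserved configuration slot;
    # protected_namespaces=() silences the warning about "model_"-prefixed fields.
    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
    # The field gets a new Python name but keeps the old external name.
    model_conf: Inner = Field(default_factory=Inner, alias="model_config")

# Callers may still pass the original keyword thanks to the alias,
# or use the new name because populate_by_name is enabled.
w = Wrapper(**{"model_config": {"task": "summarization"}})
assert w.model_conf.task == "summarization"
```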
mii/backend/server.py (19 changes: 9 additions & 10 deletions)
@@ -20,7 +20,7 @@

def config_to_b64_str(config: DeepSpeedConfigModel) -> str:
# convert json str -> bytes
json_bytes = config.json().encode()
json_bytes = config.model_dump_json().encode()
# base64 encoded bytes
b64_config_bytes = base64.urlsafe_b64encode(json_bytes)
# bytes -> str
@@ -31,7 +31,7 @@ class MIIServer:
"""Initialize the model, setup the server for the model"""
def __init__(self, mii_config: MIIConfig) -> None:

self.task = mii_config.model_config.task
self.task = mii_config.model_conf.task
self.port_number = mii_config.port_number

if not os.path.isfile(mii_config.hostfile):
@@ -47,8 +47,7 @@ def __init__(self, mii_config: MIIConfig) -> None:
# balancer process, each DeepSpeed model replica, and optionally the
# REST API process)
processes = self._initialize_service(mii_config)
self._wait_until_server_is_live(processes,
mii_config.model_config.replica_configs)
self._wait_until_server_is_live(processes, mii_config.model_conf.replica_configs)

def _wait_until_server_is_live(self,
processes: List[subprocess.Popen],
@@ -143,15 +142,15 @@ def _initialize_service(self, mii_config: MIIConfig) -> List[subprocess.Popen]:
]

host_gpus = defaultdict(list)
for repl_config in mii_config.model_config.replica_configs:
for repl_config in mii_config.model_conf.replica_configs:
host_gpus[repl_config.hostname].extend(repl_config.gpu_indices)

use_multiple_hosts = len(
set(repl_config.hostname
for repl_config in mii_config.model_config.replica_configs)) > 1
for repl_config in mii_config.model_conf.replica_configs)) > 1

# Start replica instances
for repl_config in mii_config.model_config.replica_configs:
for repl_config in mii_config.model_conf.replica_configs:
hostfile = tempfile.NamedTemporaryFile(delete=False)
hostfile.write(
f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n"
@@ -161,7 +160,7 @@ def _initialize_service(self, mii_config: MIIConfig) -> List[subprocess.Popen]:
use_multiple_hosts)
processes.append(
self._launch_server_process(
mii_config.model_config,
mii_config.model_conf,
"MII server",
ds_launch_str=ds_launch_str,
server_args=server_args + [
@@ -175,15 +174,15 @@ def _initialize_service(self, mii_config: MIIConfig) -> List[subprocess.Popen]:
# expected to assign one GPU to one process.
processes.append(
self._launch_server_process(
mii_config.model_config,
mii_config.model_conf,
"load balancer",
server_args=server_args + ["--load-balancer"],
))

if mii_config.enable_restful_api:
processes.append(
self._launch_server_process(
mii_config.model_config,
mii_config.model_conf,
"restful api gateway",
server_args=server_args + ["--restful-gateway"],
))
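The `config_to_b64_str` change above is the v1-to-v2 serialization rename: `BaseModel.json()` becomes `model_dump_json()` (with `model_validate_json()` as the parsing counterpart). A minimal round-trip sketch of the same base64 pattern, using a stand-in `DemoConfig` rather than `DeepSpeedConfigModel`:

```python
import base64
from pydantic import BaseModel

class DemoConfig(BaseModel):  # stand-in for the real config model
    port_number: int = 50050
    task: str = "text-generation"

def config_to_b64_str(config: BaseModel) -> str:
    # JSON str -> bytes -> URL-safe base64 -> str (safe to pass as a CLI argument)
    json_bytes = config.model_dump_json().encode()
    return base64.urlsafe_b64encode(json_bytes).decode()

def b64_str_to_config(b64_str: str) -> DemoConfig:
    json_str = base64.urlsafe_b64decode(b64_str.encode()).decode()
    return DemoConfig.model_validate_json(json_str)

cfg = DemoConfig(port_number=50051)
assert b64_str_to_config(config_to_b64_str(cfg)) == cfg
```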
mii/config.py (125 changes: 58 additions & 67 deletions)
@@ -8,27 +8,18 @@

from deepspeed.launcher.runner import DLTS_HOSTFILE, fetch_hostfile
from deepspeed.inference import RaggedInferenceEngineConfig
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from pydantic import Field, model_validator, field_validator

from mii.constants import DeploymentType, TaskType, ModelProvider
from mii.errors import DeploymentNotFoundError
from mii.modeling.tokenizers import MIITokenizerWrapper
from mii.pydantic_v1 import BaseModel, Field, root_validator, validator, Extra
from mii.utils import generate_deployment_name, get_default_task, import_score_file
from mii.utils import generate_deployment_name, import_score_file

DEVICE_MAP_DEFAULT = "auto"


class MIIConfigModel(BaseModel):
class Config:
validate_all = True
validate_assignment = True
use_enum_values = True
allow_population_by_field_name = True
extra = "forbid"
arbitrary_types_allowed = True


class GenerateParamsConfig(MIIConfigModel):
class GenerateParamsConfig(DeepSpeedConfigModel):
"""
Options for changing text-generation behavior.
"""
@@ -39,7 +30,7 @@ class GenerateParamsConfig(MIIConfigModel):
max_length: int = 1024
""" Maximum length of ``input_tokens`` + ``generated_tokens``. """

max_new_tokens: int = None
max_new_tokens: Optional[int] = None
""" Maximum number of new tokens generated. ``max_length`` takes precedent. """

min_new_tokens: int = 0
@@ -68,24 +59,25 @@

stop: List[str] = []
""" List of strings to stop generation at."""
@validator("stop", pre=True)
@field_validator("stop", mode="before")
@classmethod
def make_stop_string_list(cls, field_value: Union[str, List[str]]) -> List[str]:
if isinstance(field_value, str):
return [field_value]
return field_value

@validator("stop")
@field_validator("stop")
@classmethod
def sort_stop_strings(cls, field_value: List[str]) -> List[str]:
return sorted(field_value)

@root_validator
def check_prompt_length(cls, values: Dict[str, Any]) -> Dict[str, Any]:
prompt_length = values.get("prompt_length")
max_length = values.get("max_length")
assert max_length > prompt_length, f"max_length ({max_length}) must be greater than prompt_length ({prompt_length})"
return values
@model_validator(mode="after")
def check_prompt_length(self) -> "GenerateParamsConfig":
assert self.max_length > self.prompt_length, f"max_length ({self.max_length}) must be greater than prompt_length ({self.prompt_length})"
return self

@root_validator
@model_validator(mode="before")
@classmethod
def set_max_new_tokens(cls, values: Dict[str, Any]) -> Dict[str, Any]:
max_length = values.get("max_length")
max_new_tokens = values.get("max_new_tokens")
Expand All @@ -94,19 +86,16 @@ def set_max_new_tokens(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["max_new_tokens"] = max_length - prompt_length
return values

class Config:
extra = Extra.forbid


class ReplicaConfig(MIIConfigModel):
class ReplicaConfig(DeepSpeedConfigModel):
hostname: str = ""
tensor_parallel_ports: List[int] = []
torch_dist_port: int = None
torch_dist_port: Optional[int] = None
gpu_indices: List[int] = []
zmq_port: int = None
zmq_port: Optional[int] = None


class ModelConfig(MIIConfigModel):
class ModelConfig(DeepSpeedConfigModel):
model_name_or_path: str
"""
Model name or path of the HuggingFace model to be deployed.
@@ -192,8 +181,9 @@ class ModelConfig(MIIConfigModel):
def provider(self) -> ModelProvider:
return ModelProvider.HUGGING_FACE

@validator("device_map", pre=True)
def make_device_map_dict(cls, v):
@field_validator("device_map", mode="before")
@classmethod
def make_device_map_dict(cls, v: Any) -> Dict:
if isinstance(v, int):
return {"localhost": [[v]]}
if isinstance(v, list) and isinstance(v[0], int):
@@ -202,36 +192,36 @@ def make_device_map_dict(cls, v):
return {"localhost": v}
return v

@root_validator
@model_validator(mode="before")
@classmethod
def auto_fill_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
assert values.get("model_name_or_path"), "model_name_or_path must be provided"
if not values.get("tokenizer"):
values["tokenizer"] = values.get("model_name_or_path")
if not values.get("task"):
values["task"] = get_default_task(values.get("model_name_or_path"))
#if not values.get("task"):
# values["task"] = get_default_task(values.get("model_name_or_path"))
values["task"] = TaskType.TEXT_GENERATION
return values

@root_validator
def propagate_tp_size(cls, values: Dict[str, Any]) -> Dict[str, Any]:
tensor_parallel = values.get("tensor_parallel")
values.get("inference_engine_config").tensor_parallel.tp_size = tensor_parallel
return values

@root_validator
def propagate_quantization_mode(cls, values: Dict[str, Any]) -> Dict[str, Any]:
quantization_mode = values.get("quantization_mode")
values.get(
"inference_engine_config").quantization.quantization_mode = quantization_mode
return values
@model_validator(mode="after")
def propagate_tp_size(self) -> "ModelConfig":
self.inference_engine_config.tensor_parallel.tp_size = self.tensor_parallel
return self

@root_validator
def check_replica_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
num_replica_config = len(values.get("replica_configs"))
@model_validator(mode="after")
def check_replica_config(self) -> "ModelConfig":
num_replica_config = len(self.replica_configs)
if num_replica_config > 0:
assert num_replica_config == values.get("replica_num"), "Number of replica configs must match replica_num"
return values
assert num_replica_config == self.replica_num, "Number of replica configs must match replica_num"
return self

@model_validator(mode="after")
def propagate_quantization_mode(self) -> "ModelConfig":
self.inference_engine_config.quantization.quantization_mode = self.quantization_mode
return self


class MIIConfig(MIIConfigModel):
class MIIConfig(DeepSpeedConfigModel):
deployment_name: str = ""
"""
Name of the deployment. Used as an identifier for obtaining an inference
@@ -245,7 +235,7 @@ class MIIConfig(MIIConfigModel):
* `AML` will generate the assets necessary to deploy on AML resources.
"""

model_config: ModelConfig
model_conf: ModelConfig = Field(alias="model_config")
"""
Configuration for the deployed model(s).
"""
@@ -290,17 +280,18 @@ class MIIConfig(MIIConfigModel):
"""
AML instance type to use when creating AML deployment assets.
"""
@root_validator(skip_on_failure=True)
def AML_name_valid(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("deployment_type") == DeploymentType.AML:
@model_validator(mode="after")
def AML_name_valid(self) -> "MIIConfig":
if self.deployment_type == DeploymentType.AML:
allowed_chars = set(string.ascii_lowercase + string.ascii_uppercase +
string.digits + "-")
assert (
set(values.get("deployment_name")) <= allowed_chars
set(self.deployment_name) <= allowed_chars
), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'."
return values
return self

@root_validator(skip_on_failure=True)
@model_validator(mode="before")
@classmethod
def check_deployment_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
deployment_name = values.get("deployment_name")
if not deployment_name:
@@ -311,14 +302,14 @@ def check_deployment_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values

def generate_replica_configs(self) -> None:
if self.model_config.replica_configs:
if self.model_conf.replica_configs:
return
torch_dist_port = self.model_config.torch_dist_port
tensor_parallel = self.model_config.tensor_parallel
torch_dist_port = self.model_conf.torch_dist_port
tensor_parallel = self.model_conf.tensor_parallel
replica_pool = _allocate_devices(self.hostfile,
tensor_parallel,
self.model_config.replica_num,
self.model_config.device_map)
self.model_conf.replica_num,
self.model_conf.device_map)
replica_configs = []
for i, (hostname, gpu_indices) in enumerate(replica_pool):
# Reserve a port for an LB proxy when replication is enabled
@@ -332,10 +323,10 @@ def generate_replica_configs(self) -> None:
tensor_parallel_ports=tensor_parallel_ports,
torch_dist_port=replica_torch_dist_port,
gpu_indices=gpu_indices,
zmq_port=self.model_config.zmq_port_number + i,
zmq_port=self.model_conf.zmq_port_number + i,
))

self.model_config.replica_configs = replica_configs
self.model_conf.replica_configs = replica_configs


def _allocate_devices(hostfile_path: str,
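Most of the churn in `mii/config.py` follows one recipe: the custom `MIIConfigModel` base (a Pydantic v1 `class Config` wrapper) is dropped in favor of DeepSpeed's `DeepSpeedConfigModel`, `@validator` becomes `@field_validator` plus an explicit `@classmethod`, and `@root_validator` becomes `@model_validator`, where `mode="after"` validators receive the constructed instance instead of a `values` dict. A simplified, hedged sketch of that migration shape (`GenParams` is a toy model, not the real `GenerateParamsConfig`):

```python
from typing import List, Union
from pydantic import BaseModel, field_validator, model_validator

class GenParams(BaseModel):
    prompt_length: int = 0
    max_length: int = 1024
    stop: List[str] = []

    # v1: @validator("stop", pre=True)
    @field_validator("stop", mode="before")
    @classmethod
    def make_stop_string_list(cls, v: Union[str, List[str]]) -> List[str]:
        return [v] if isinstance(v, str) else v

    # v1: @root_validator, which operated on a `values` dict
    @model_validator(mode="after")
    def check_prompt_length(self) -> "GenParams":
        assert self.max_length > self.prompt_length
        return self

p = GenParams(stop="</s>", prompt_length=16)
assert p.stop == ["</s>"]
```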
mii/legacy/client.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ def mii_query_handle(deployment_name):
return MIINonPersistentClient(task, deployment_name)

mii_config = _get_mii_config(deployment_name)
return MIIClient(mii_config.model_config.task,
return MIIClient(mii_config.model_conf.task,
"localhost", # TODO: This can probably be removed
mii_config.port_number)
