port fp8 mixtral #460

Merged: 3 commits, May 21, 2024
Changes from all commits
python/sglang/srt/managers/router/model_rpc.py (9 changes: 1 addition & 8 deletions)
@@ -69,20 +69,13 @@ def __init__(
)

# For model end global settings
server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}

self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=port_args.nccl_port,
load_format=server_args.load_format,
trust_remote_code=server_args.trust_remote_code,
server_args_dict=server_args_dict,
server_args=server_args,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
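The model_rpc.py change collapses the runner's configuration into a single ServerArgs object: the hand-built server_args_dict and the separate load_format / trust_remote_code keywords disappear, and ModelRunner reads everything it needs from server_args. A minimal sketch of the resulting call shape (the wrapper function and its parameter names are ours for illustration; only the ModelRunner keywords come from the diff):

from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs


def build_model_runner(model_config, tp_rank: int, nccl_port: int, server_args: ServerArgs):
    # After this PR, ModelRunner no longer takes load_format, trust_remote_code,
    # or a server_args_dict; it pulls those settings from server_args itself.
    return ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        nccl_port=nccl_port,
        server_args=server_args,
    )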
python/sglang/srt/managers/router/model_runner.py (28 changes: 16 additions & 12 deletions)
@@ -17,6 +17,7 @@

from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model


@@ -218,36 +219,39 @@ def __init__(
tp_rank,
tp_size,
nccl_port,
load_format="auto",
trust_remote_code=True,
server_args_dict: dict = {},
server_args: ServerArgs,
):
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.tp_rank = tp_rank
self.tp_size = tp_size
self.nccl_port = nccl_port
self.load_format = load_format
self.trust_remote_code = trust_remote_code
self.server_args = server_args

global global_server_args_dict
global_server_args_dict = server_args_dict
global_server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}

# Init torch distributed
logger.debug("Init torch begin.")
torch.cuda.set_device(self.tp_rank)
torch.distributed.init_process_group(
backend="nccl",
world_size=self.tp_size,
rank=self.tp_rank,
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
)

initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
logger.debug("Init torch end.")

total_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
# logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
self.load_model()
# logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
self.init_memory_pool(total_gpu_memory)

self.is_multimodal_model = is_multimodal_model(self.model_config)
@@ -256,15 +260,15 @@ def load_model(self):
logger.info(f"Rank {self.tp_rank}: load weight begin.")

device_config = DeviceConfig()
load_config = LoadConfig()
load_config = LoadConfig(load_format=self.server_args.load_format)
vllm_model_config = VllmModelConfig(
model=self.model_config.path,
model=self.server_args.model_path,
quantization=self.server_args.quantization,
tokenizer=None,
tokenizer_mode=None,
trust_remote_code=self.model_config.trust_remote_code,
trust_remote_code=self.server_args.trust_remote_code,
dtype=torch.float16,
seed=42,
revision=self.model_config.revision,
skip_tokenizer_init=True,
)
if self.model_config.model_overide_args is not None:
@@ -279,7 +283,7 @@
parallel_config=None,
scheduler_config=None,
)
logger.info(f"Rank {self.tp_rank}: load weight end.")
logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")

def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory(
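Taken together with the ported fp8 Mixtral model, the model_runner.py change means the quantization and load_format settings now flow from ServerArgs into vLLM's LoadConfig and ModelConfig inside load_model(). A sketch of the configuration this enables follows; the example model path and values are illustrative, and the assumption that ServerArgs is a dataclass constructible like this is ours, with only the field names taken from the diff:

from sglang.srt.server_args import ServerArgs

# Illustrative only: request an fp8-quantized Mixtral through the consolidated ServerArgs.
server_args = ServerArgs(
    model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",  # example checkpoint, not from the PR
    tp_size=2,                   # example tensor-parallel degree
    quantization="fp8",          # read by load_model() and passed to vLLM's ModelConfig
    load_format="auto",          # read by load_model() and passed to vLLM's LoadConfig
)

# Per the diff above, load_model() then builds:
#   LoadConfig(load_format=server_args.load_format)
#   VllmModelConfig(model=server_args.model_path,
#                   quantization=server_args.quantization, ...)
# so the fp8 quantization config reaches the Mixtral weights as they are loaded.

On the command line this presumably maps to a --quantization fp8 flag on the launch script; that flag is inferred from the ServerArgs field and is not shown in this diff.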