fix tp in all parameter read
zhaochenyang20 committed Nov 27, 2024
1 parent 918c47b commit c293a6a
Showing 4 changed files with 955 additions and 286 deletions.
14 changes: 13 additions & 1 deletion docs/backend/native_api.ipynb
@@ -338,7 +338,19 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"terminate_process(reward_process)"
]
67 changes: 22 additions & 45 deletions python/sglang/srt/model_executor/model_runner.py
@@ -443,50 +443,6 @@ def init_parameter_update_group(
logger.error(message)
return False, message

def update_parameter_from_distributed(self, name, dtype, shape, empty_cache=False):
"""
Update specific parameter in the model weights online through the process group.
Args:
name: the name of the parameter to be updated.
dtype: the data type of the parameter to be updated.
shape: the shape of the parameter to be updated.
empty_cache: whether to empty the cache after updating the parameter.
"""
target_dtype = (
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
)

current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype

assert str(target_dtype) == str(
current_dtype
), f"dtype mismatch: target={dtype} vs current model runner={self.dtype}"
assert (
self._model_update_group is not None
), "model update group must be initialized"

try:

weights = torch.empty(shape, dtype=target_dtype, device=self.device)

torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
weights_dat = [(name, weights)]
self.model.load_weights(weights_dat)
if empty_cache:
torch.cuda.empty_cache()

return True, f"Succeeded to update parameter {name} online."

except Exception as e:
error_msg = (
f"Failed to update parameter online: {e}. "
f"The full weights of the ModelRunner are partially updated. "
f"Please discard the whole weights."
)
logger.error(error_msg)
return False, error_msg
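
For context, the method removed above implements the receiving side of a one-parameter update: rank 0 of the update group broadcasts the tensor, each worker allocates an empty buffer of the announced shape and dtype, joins the broadcast, and loads the result into the model. A minimal sketch of the matching sender side, assuming the trainer process holds rank 0 of the same group (the function name and signature here are illustrative, not part of this commit):

import torch
import torch.distributed as dist

def send_parameter(update_group, name: str, tensor: torch.Tensor):
    # Hypothetical sender-side sketch mirroring the receiver's
    # torch.distributed.broadcast(weights, src=0, group=self._model_update_group).
    # Assumes this process is rank 0 of `update_group` and that `name`, the dtype,
    # and the shape have already been communicated to the receiver (e.g. via an API call).
    dist.broadcast(tensor, src=0, group=update_group)
    return name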

def get_weights_by_parameter_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
@@ -527,14 +483,35 @@ def get_weights_by_parameter_name(
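# Likely the v_proj slice of a fused qkv weight (the enclosing branch is
# truncated above): q occupies the first num_heads * head_dim rows and k the
# next num_kv_heads * head_dim, so v starts at (num_heads + num_kv_heads) * head_dim.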
offset = (num_heads + num_kv_heads) * head_dim
size = num_kv_heads * head_dim

# Extract the corresponding slice of the weight
weight = param.data.narrow(0, offset, size)
elif mapped_shard_id in [0, 1]:
# Handle the fused gate_up_proj case
intermediate_size = self.model.config.intermediate_size
hidden_size = self.model.config.hidden_size
slice_size = intermediate_size // self.tp_size

if mapped_shard_id == 0: # gate_proj
offset = 0
size = slice_size
elif mapped_shard_id == 1: # up_proj
offset = slice_size
size = slice_size

# Extract the corresponding slice of the weight
weight = param.data.narrow(0, offset, size)
else:
weight = param.data
else:
weight = param.data

# Gather TP-sharded weights if needed, then convert and truncate
if self.tp_size > 1 and ("o_proj" in name or "down_proj" in name):
gathered_weights = [
torch.zeros_like(weight) for _ in range(self.tp_size)
]
torch.distributed.all_gather(gathered_weights, weight)
weight = torch.cat(gathered_weights, dim=1)

return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
else:
logger.warning(
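
The all_gather added above is the substance of the fix in the commit title: under tensor parallelism, o_proj and down_proj are row-parallel, so each rank holds only a slice of the input dimension and the full matrix must be reassembled before it can be compared with the unsharded HuggingFace parameter. A standalone sketch of the same reassembly, assuming a process group with tp_size ranks is already initialized (the helper name is illustrative):

import torch
import torch.distributed as dist

def gather_row_parallel_weight(local_shard: torch.Tensor, tp_size: int) -> torch.Tensor:
    # Each rank holds an [out_dim, in_dim // tp_size] slice of a row-parallel weight.
    gathered = [torch.zeros_like(local_shard) for _ in range(tp_size)]
    dist.all_gather(gathered, local_shard)
    # Concatenating along dim 1 restores the full [out_dim, in_dim] matrix,
    # matching the torch.cat(gathered_weights, dim=1) call in the diff above.
    return torch.cat(gathered, dim=1)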
130 changes: 68 additions & 62 deletions test/srt/test_get_parameter_by_name.py
@@ -1,9 +1,5 @@
import gc
import json
import time
import unittest
from itertools import product
from signal import signal

import numpy as np
import requests
@@ -31,32 +27,42 @@ def setUpClass(cls):

@classmethod
def init_engine_and_server(cls, engine_tp, server_tp, engine_dp, server_dp):
cls.engine = sgl.Engine(
model_path=cls.model,
random_seed=42,
tp_size=engine_tp,
dp_size=engine_dp,
base_gpu_id=0,
mem_fraction_static=0.85,
)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--base-gpu-id",
str(engine_dp * engine_tp),
"--tp-size",
str(server_tp),
"--dp-size",
str(server_dp),
),
)
cls.engine = None
cls.process = None
cls.engine_dp = engine_dp
cls.server_dp = server_dp
cls.engine_tp = engine_tp
cls.server_tp = server_tp
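# dp == 0 acts as a skip flag: the engine (or server) is launched only when
# its dp value is nonzero, which lets the suites below exercise engine-only
# and server-only configurations.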
if engine_dp != 0:
cls.engine = sgl.Engine(
model_path=cls.model,
random_seed=42,
tp_size=engine_tp,
dp_size=engine_dp,
base_gpu_id=0,
mem_fraction_static=0.85,
)
if server_dp != 0:
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--base-gpu-id",
str(engine_dp * engine_tp),
"--tp-size",
str(server_tp),
"--dp-size",
str(server_dp),
),
)

@classmethod
def close_engine_and_server(cls):
cls.engine.shutdown()
terminate_process(cls.process)
if cls.engine:
cls.engine.shutdown()
if cls.process:
terminate_process(cls.process)

@classmethod
def tearDownClass(cls):
@@ -66,46 +72,38 @@ def tearDownClass(cls):

@classmethod
def assert_update_weights_all_close(cls, param_name, truncate_size):
param = cls.hf_model.get_parameter(param_name)[:truncate_size]
engine_ret = cls.engine.get_weights_by_parameter_name(param_name, truncate_size)

# If engine_ret is a list of scalar values
if isinstance(engine_ret, list) and len(engine_ret) == 2:
print("working in dp, engine_ret is a list of two elements")
np.testing.assert_allclose(engine_ret[0], engine_ret[1])
engine_ret = engine_ret[0]

engine_ret = np.array(engine_ret)

np.testing.assert_allclose(
engine_ret, param.cpu().detach().float().numpy(), rtol=1e-5, atol=1e-5
)

runtime_ret = requests.get(
f"{cls.base_url}/get_weights_by_parameter_name",
json={
"name": param_name,
"truncate_size": truncate_size,
},
).json()

# Handle the runtime_ret case
if isinstance(runtime_ret, list) and len(runtime_ret) == 2:
np.testing.assert_allclose(runtime_ret[0], runtime_ret[1])
runtime_ret = runtime_ret[0]

np.testing.assert_allclose(
runtime_ret,
param.cpu().detach().float().numpy(),
rtol=1e-5,
atol=1e-5,
print(
f"param_name: {param_name}, engine_dp: {cls.engine_dp}, server_dp: {cls.server_dp}, engine_tp: {cls.engine_tp}, server_tp: {cls.server_tp}"
)
param = cls.hf_model.get_parameter(param_name)[:truncate_size]
param_np = param.cpu().detach().float().numpy()

if cls.engine:
engine_ret = cls.engine.get_weights_by_parameter_name(
param_name, truncate_size
)
engine_ret = cls._process_return(engine_ret)
np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)

if cls.process:
runtime_ret = requests.get(
f"{cls.base_url}/get_weights_by_parameter_name",
json={"name": param_name, "truncate_size": truncate_size},
).json()
runtime_ret = cls._process_return(runtime_ret)
np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)

@staticmethod
def _process_return(ret):
if isinstance(ret, list) and len(ret) == 2:
np.testing.assert_allclose(ret[0], ret[1])
return np.array(ret[0])
return np.array(ret)
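
Under data parallelism each replica returns its own copy of the weights, so _process_return first checks that the two copies agree and then keeps one. A hypothetical illustration of the shape it handles (values invented):

# Hypothetical input for dp_size == 2: one copy of the truncated weights per replica.
ret = [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
# _process_return asserts the copies match and returns np.array([0.1, 0.2, 0.3]).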

@classmethod
def test_update_weights_unexist_model(cls):

assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
test_suits = [(1, 1, 1, 1)]
test_suits = [(1, 1, 1, 1), (2, 0, 1, 0), (0, 2, 0, 1)]

if torch.cuda.device_count() >= 4:
test_suits.extend([(2, 2, 1, 1), (1, 1, 2, 2)])
@@ -114,9 +112,17 @@ def test_update_weights_unexist_model(cls):
test_suits.append((2, 2, 2, 2))

parameters = [
"model.embed_tokens.weight",
"model.layers.0.input_layernorm.weight",
"model.layers.1.self_attn.q_proj.weight",
"model.layers.2.self_attn.k_proj.weight",
"model.layers.3.self_attn.v_proj.weight",
"model.layers.4.self_attn.o_proj.weight",
"model.layers.5.mlp.gate_proj.weight",
"model.layers.6.mlp.up_proj.weight",
"model.layers.7.mlp.down_proj.weight",
"model.layers.8.post_attention_layernorm.weight",
"model.norm.weight",
"lm_head.weight",
]
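
The loop that pairs each suite with each parameter is behind the truncation below; a minimal sketch of what it plausibly looks like, assuming the tuples unpack as (engine_tp, server_tp, engine_dp, server_dp) to match init_engine_and_server (this driver is an assumption, not the commit's code):

# Hypothetical driver; the real body is truncated below.
for engine_tp, server_tp, engine_dp, server_dp in test_suits:
    cls.init_engine_and_server(engine_tp, server_tp, engine_dp, server_dp)
    for param_name in parameters:
        cls.assert_update_weights_all_close(param_name, truncate_size=100)
    cls.close_engine_and_server()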

