[LLM Inference] support llama3.1 #8929

Merged (2 commits) on Aug 14, 2024
90 changes: 62 additions & 28 deletions llm/predict/predictor.py
@@ -775,8 +775,10 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):

try:
self.rope_theta = self.model_config.rope_theta
self.rope_scaling = self.model_config.rope_scaling
except:
self.rope_theta = 10000.0
self.rope_scaling = None

self.pre_cache_length = 0

@@ -874,7 +876,7 @@ def init_model_inputs(self, config: PredictorArgument):
shape=[config.batch_size, 1], fill_value=config.max_length, dtype="int64"
)
self.model_inputs["rope_emb"] = self._get_rotary_position_embedding(
paddle.arange(config.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta
paddle.arange(config.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta, self.rope_scaling
)
self.model_inputs["bad_tokens"] = paddle.to_tensor([-1], dtype="int64")
self.model_inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
@@ -909,7 +911,7 @@ def init_model_inputs(self, config: PredictorArgument):
alibi_decoder + (1 - self.model_inputs["tgt_mask"]) * paddle.finfo(self.dtype).min
).cast(self.dtype)

def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=10000.0):
def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=10000.0, rope_scaling: dict = None):
"""
Pre-calculate rotary position embedding for position_ids.

@@ -924,6 +926,33 @@ def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=1000
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32")
inv_freq = rope_theta ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)

if rope_scaling is not None:
rope_type = rope_scaling.get("rope_type", None)
if rope_type is not None and rope_type == "llama3":
factor = rope_scaling.get("factor", 8.0)
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings", 8192)

low_freq_wavelen = original_max_position_embeddings / low_freq_factor
high_freq_wavelen = original_max_position_embeddings / high_freq_factor
new_freqs = []
for freq in inv_freq:
import math

wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
inv_freq = paddle.to_tensor(new_freqs, dtype=inv_freq.dtype)

# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, 1, D]
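For reference, the new branch above implements the Llama 3.1 ("llama3" rope_type) frequency rescaling of the rotary inverse frequencies. Below is a minimal standalone sketch of the same computation, written with numpy purely for illustration; the predictor itself operates on paddle tensors, and the example head_dim/rope_theta values are illustrative rather than read from a config.

```python
# Standalone sketch of the "llama3" rope_scaling branch added above (illustration only).
import math
import numpy as np

def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                          high_freq_factor=4.0, original_max_position_embeddings=8192):
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    new_freqs = []
    for freq in inv_freq:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # high-frequency (short-wavelength) components are kept unchanged
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # low-frequency components are slowed down by `factor`
            new_freqs.append(freq / factor)
        else:
            # middle band: smoothly interpolate between the two regimes
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    return np.asarray(new_freqs, dtype=np.float32)

# Example values only; not taken from a real model config.
head_dim, rope_theta = 128, 500000.0
inv_freq = rope_theta ** (-np.arange(0, head_dim, 2, dtype=np.float32) / head_dim)
print(llama3_scale_inv_freq(inv_freq)[:4])
```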
@@ -1029,24 +1058,28 @@ def predict(self, input_texts: list[str], return_tokens=False):
tensor_queue.put(output_tensor)

read_res_process = mp.Process(target=read_res, args=[self.model_name_or_path, tensor_queue, result_queue])
read_res_process.start()
if self.tensor_parallel_rank == 0:
read_res_process.start()

s_time = time.time()
while self.model_inputs["not_need_stop"]:
self._infer(self.model_inputs)
logger.info(f"running spend {time.time() - s_time}")

outputs = []
output_tokens = []
while len(outputs) < self.batch_size:
result = result_queue.get(timeout=1)
outputs.append(result[-1])
output_tokens.append(result[-2])
if self.tensor_parallel_rank == 0:
outputs = []
output_tokens = []
while len(outputs) < self.batch_size:
result = result_queue.get(timeout=1)
outputs.append(result[-1])
output_tokens.append(result[-2])

read_res_process.terminate()
read_res_process.terminate()

if return_tokens:
return outputs, output_tokens
else:
return outputs
if return_tokens:
return outputs, output_tokens
else:
return outputs


class StaticBlockInferencePredictor(BlockInferencePredictorMixin):
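For context on the predict() change above (the same change is applied to StaticBlockInferencePredictor.predict() below): only tensor-parallel rank 0 now spawns the read_res process and drains the result queue, so the other ranks no longer start duplicate readers. A minimal, self-contained sketch of that gating pattern follows; the names are illustrative and this is not PaddleNLP code.

```python
# Minimal sketch of the rank-0 result-reader gating; names are illustrative.
import multiprocessing as mp

def read_results(result_queue: mp.Queue) -> None:
    # Stand-in for llm/utils/utils.py::read_res: push (tokens, text) results.
    result_queue.put(([1, 2, 3], "hello"))

def predict_on_rank(tensor_parallel_rank: int, batch_size: int = 1):
    result_queue = mp.Queue()
    reader = None
    if tensor_parallel_rank == 0:
        # Only rank 0 spawns the reader process and later consumes its output.
        reader = mp.Process(target=read_results, args=(result_queue,))
        reader.start()

    # ... every rank would run the inference loop here ...

    outputs, output_tokens = [], []
    if tensor_parallel_rank == 0:
        while len(outputs) < batch_size:
            tokens, text = result_queue.get(timeout=1)
            outputs.append(text)
            output_tokens.append(tokens)
        reader.terminate()
    return outputs, output_tokens

if __name__ == "__main__":
    print(predict_on_rank(0))  # (['hello'], [[1, 2, 3]])
```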
@@ -1112,8 +1145,6 @@ def _create_predictor(self, predictor_args: PredictorArgument):
else:
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
config.enable_use_gpu(100, device_id)
# config.disable_glog_info()
# config.enable_memory_optim()

if predictor_args.device == "npu":
import paddle_custom_device.npu.passes as passes
@@ -1149,26 +1180,29 @@ def predict(self, input_texts: list[str], return_tokens=False):
tensor_queue.put(output_tensor)

read_res_process = mp.Process(target=read_res, args=[self.model_name_or_path, tensor_queue, result_queue])
read_res_process.start()

if self.tensor_parallel_rank == 0:
read_res_process.start()

s_time = time.time()
while self.model_inputs["not_need_stop"]:
self.predictor.run(list(self.model_inputs.values()))
logger.info(f"running spend {time.time() - s_time}")

outputs = []
output_tokens = []
while len(outputs) < self.batch_size:
result = result_queue.get(timeout=1)
outputs.append(result[-1])
output_tokens.append(result[-2])
if self.tensor_parallel_rank == 0:
outputs = []
output_tokens = []
while len(outputs) < self.batch_size:
result = result_queue.get(timeout=1)
outputs.append(result[-1])
output_tokens.append(result[-2])

read_res_process.terminate()
read_res_process.terminate()

if return_tokens:
return outputs, output_tokens
else:
return outputs
if return_tokens:
return outputs, output_tokens
else:
return outputs


def get_ptq_multicards_num(directory):
6 changes: 3 additions & 3 deletions llm/utils/utils.py
@@ -230,7 +230,7 @@ def get_lora_target_modules(model):
".*up_proj.*",
".*down_proj.*",
]
else:
else:
raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
return target_modules

@@ -763,9 +763,9 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q

while True:
get_output(output_tensor, 0, True)
if output_tensor[0, 0] == -2: # read none
if int(output_tensor[0, 0]) == -2: # read none
continue
bsz = output_tensor[1, 0].numpy()
bsz = int(output_tensor[1, 0])
output_numpy = output_tensor[2 : bsz + 2].numpy()
output_numpy[output_numpy == -1] = 2
outputs.append(output_numpy)
@@ -203,7 +203,6 @@ def __init__(
self.num_heads = num_heads
if kv_num_heads > 0:
self.kv_num_heads = kv_num_heads
assert nranks == 1, "nranks should be 1 for kv_num_heads > 0"
else:
self.kv_num_heads = num_heads
self.dim_feedforward = dim_feedforward
3 changes: 1 addition & 2 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -672,8 +672,7 @@

# sample
probs = F.softmax(logits)
# _, next_tokens = top_p_sampling(probs, top_p, -1)
_, next_tokens = paddle.topk(probs, 1, -1)
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)

if self.config.tensor_parallel_degree > 1:
paddle.distributed.broadcast(next_tokens, 0)
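The sampling change above replaces the temporary greedy paddle.topk fallback with the fused paddle.tensor.top_p_sampling kernel. As an illustration of what nucleus (top-p) sampling computes for one probability row, here is a small numpy sketch; it is only a description of the semantics, not the fused kernel, and the helper name is made up.

```python
# Illustrative top-p (nucleus) sampling for a single probability row (sketch only).
import numpy as np

def top_p_sample(probs: np.ndarray, top_p: float, rng=None) -> int:
    rng = rng or np.random.default_rng()
    order = np.argsort(-probs)                       # token ids, most probable first
    sorted_probs = probs[order]
    cumulative = np.cumsum(sorted_probs)
    cutoff = np.searchsorted(cumulative, top_p) + 1  # smallest prefix with mass >= top_p
    nucleus = order[:cutoff]
    nucleus_probs = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
    return int(rng.choice(nucleus, p=nucleus_probs))

probs = np.array([0.5, 0.3, 0.15, 0.05])
print(top_p_sample(probs, top_p=0.8))  # samples token 0 or 1 only
```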
65 changes: 53 additions & 12 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -371,17 +371,7 @@
self.quant_type
)

if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
self.vocab_size,
self.hidden_size,
weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
)
else:
self.embed_tokens = nn.Embedding(
self.vocab_size,
self.hidden_size,
)
self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)

Contributor: why?

Collaborator (author): Multi-card inference ran into two problems: 1. with block_attn set to false, dynamic-to-static conversion raises an error; 2. with block_attn set to true, running raises an error.


# get ring_id
ring_id = -1
@@ -1256,6 +1246,58 @@
self.llama = LlamaInferenceModel(config)
self.lm_head = LlamaLMHead(config)

@classmethod
def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):

from paddlenlp.transformers.conversion_utils import split_or_merge_func

fn = split_or_merge_func(

is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
)

def get_tensor_parallel_split_mappings(num_layers):
final_actions = {}

base_actions = {

"lm_head.weight": partial(fn, is_column=True),
# Row Linear
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
"layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
}

# Column Linear
if config.fuse_attention_qkv:
base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)

else:
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)

# if we have enough num_key_value_heads to split, then split it.
if config.num_key_value_heads % config.tensor_parallel_degree == 0:
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)

if config.fuse_attention_ffn:
base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(

fn, is_column=True, is_naive_2fuse=True
)
else:
base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)

for key, action in base_actions.items():
if "layers.0." in key:
for i in range(num_layers):
final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
final_actions[key] = action

return final_actions

mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)

return mappings

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)
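The new _get_tensor_parallel_mappings classmethod above defines its split actions once on "layers.0." template keys and then replicates them for every decoder layer. A small sketch of that expansion step follows; `split_weight` is a placeholder standing in for the real split_or_merge_func from conversion_utils.

```python
# Sketch of the "layers.0." key expansion used above; split_weight is a placeholder.
from functools import partial

def split_weight(name: str, is_column: bool = True) -> str:
    # Stand-in for the real tensor-parallel split/merge function.
    return f"split({name}, column={is_column})"

def expand_layer_actions(base_actions: dict, num_layers: int) -> dict:
    final_actions = {}
    for key, action in base_actions.items():
        if "layers.0." in key:
            for i in range(num_layers):
                final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
        final_actions[key] = action
    return final_actions

base = {
    "lm_head.weight": partial(split_weight, is_column=True),                     # column split
    "layers.0.self_attn.o_proj.weight": partial(split_weight, is_column=False),  # row split
}
mappings = expand_layer_actions(base, num_layers=2)
print(sorted(mappings))
# ['layers.0.self_attn.o_proj.weight', 'layers.1.self_attn.o_proj.weight', 'lm_head.weight']
```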
@@ -1435,7 +1477,6 @@
base_actions = {
"lm_head.weight": partial(fn, is_column=True),
# Row Linear
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
"layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
}