Add sequence parallel strategy support. #734

Merged
merged 2 commits · Sep 16, 2022
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_base.yaml
@@ -31,6 +31,7 @@ Model:
module: "GPTModule"
name: "GPT"
fused_linear: False
sequence_parallel: False


Data:
129 changes: 106 additions & 23 deletions ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py
@@ -31,6 +31,9 @@
from paddle.distributed.fleet.utils import recompute
import sys

from .sequence_parallel_utils import ScatterOp, GatherOp, \
all_reduce_gradient_hook, ColumnSequenceParallelLinear, RowSequenceParallelLinear


def get_attr(layer, name):
if getattr(layer, name, None) is not None:
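The new import pulls the sequence-parallel building blocks out of `sequence_parallel_utils.py`, which is not shown in this diff. The sketch below is only a conceptual illustration of what `ScatterOp` and `GatherOp` typically do in a Megatron-style sequence-parallel setup (the `_mp_group` helper and all internals are assumptions, not the ppfleetx source): split or gather activations along the sequence dimension in forward, with the matching collective in backward.

```python
# Conceptual sketch only -- not the ppfleetx implementation.
# Assumes an initialized hybrid-parallel job; _mp_group is a hypothetical helper.
import paddle
import paddle.distributed as dist
from paddle.autograd import PyLayer
from paddle.distributed import fleet


def _mp_group():
    return fleet.get_hybrid_communicate_group().get_model_parallel_group()


class ScatterOp(PyLayer):
    """Forward: keep this rank's shard of the sequence dim (axis 0).
    Backward: all-gather the gradient so the full sequence is restored."""

    @staticmethod
    def forward(ctx, x):
        group = _mp_group()
        return paddle.split(x, group.nranks, axis=0)[group.rank]

    @staticmethod
    def backward(ctx, grad):
        group = _mp_group()
        shards = []
        dist.all_gather(shards, grad, group=group)
        return paddle.concat(shards, axis=0)


class GatherOp(PyLayer):
    """Forward: all-gather the sequence shards. Backward: keep the local shard."""

    @staticmethod
    def forward(ctx, x):
        group = _mp_group()
        shards = []
        dist.all_gather(shards, x, group=group)
        return paddle.concat(shards, axis=0)

    @staticmethod
    def backward(ctx, grad):
        group = _mp_group()
        return paddle.split(grad, group.nranks, axis=0)[group.rank]
```

Both ops are applied exactly as in the diff, e.g. `embeddings = ScatterOp.apply(embeddings)`.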
@@ -87,7 +90,8 @@ def __init__(self,
num_partitions=1,
fused_linear=False,
use_recompute=False,
recompute_granularity="full"):
recompute_granularity="full",
sequence_parallel=False):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
@@ -98,6 +102,14 @@ def __init__(self,
self.fuse = fuse
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.sequence_parallel = sequence_parallel

if sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
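When `sequence_parallel` is enabled, the column/row tensor-parallel linears are swapped for their sequence-parallel counterparts. Their internals live in `sequence_parallel_utils.py` and are not part of this diff; the sketch below only illustrates the forward-pass math such layers are expected to perform (function names and the `_mp_group` helper are assumptions): the column variant all-gathers the sequence-split input before the column-split matmul, and the row variant replaces tensor parallelism's usual all-reduce with a reduce-scatter along the sequence dimension, emulated here as all-reduce plus local slice.

```python
# Forward-pass sketch only; the real classes also handle bias, backward,
# weight creation and a fused reduce-scatter.
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet


def _mp_group():
    return fleet.get_hybrid_communicate_group().get_model_parallel_group()


def column_sp_forward(x_shard, w_col):
    # x_shard: [s/n, b, h]; w_col: column shard [h, out/n]
    group = _mp_group()
    shards = []
    dist.all_gather(shards, x_shard, group=group)            # -> full sequence [s, b, h]
    return paddle.matmul(paddle.concat(shards, axis=0), w_col)  # [s, b, out/n]


def row_sp_forward(x_full, w_row):
    # x_full: [s, b, in/n]; w_row: row shard [in/n, out]
    group = _mp_group()
    partial = paddle.matmul(x_full, w_row)                   # partial sums [s, b, out]
    dist.all_reduce(partial, group=group)                    # real impl: reduce-scatter
    return paddle.split(partial, group.nranks, axis=0)[group.rank]  # [s/n, b, out]
```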
@@ -110,39 +122,39 @@ def __init__(self,
assert self.kdim == embed_dim
assert self.vdim == embed_dim

self.qkv_proj = fleet.meta_parallel.ColumnParallelLinear(
self.qkv_proj = ColumnParallelLinear(
embed_dim,
3 * embed_dim,
weight_attr=weight_attr,
has_bias=True,
gather_output=False,
fuse_matmul_bias=fused_linear)
else:
self.q_proj = fleet.meta_parallel.ColumnParallelLinear(
self.q_proj = ColumnParallelLinear(
embed_dim,
embed_dim,
weight_attr=weight_attr,
has_bias=True,
gather_output=False,
fuse_matmul_bias=fused_linear)

self.k_proj = fleet.meta_parallel.ColumnParallelLinear(
self.k_proj = ColumnParallelLinear(
self.kdim,
embed_dim,
weight_attr=weight_attr,
has_bias=True,
gather_output=False,
fuse_matmul_bias=fused_linear)

self.v_proj = fleet.meta_parallel.ColumnParallelLinear(
self.v_proj = ColumnParallelLinear(
self.vdim,
embed_dim,
weight_attr=weight_attr,
has_bias=True,
gather_output=False,
fuse_matmul_bias=fused_linear)

self.out_proj = fleet.meta_parallel.RowParallelLinear(
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
weight_attr=weight_attr,
@@ -154,7 +166,10 @@ def _fuse_prepare_qkv(self, query, use_cache=False, cache=None):
mix_layer = self.qkv_proj(query)
mix_layer = paddle.reshape_(mix_layer,
[0, 0, self.num_heads, 3 * self.head_dim])
mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
if self.sequence_parallel:
mix_layer = paddle.transpose(mix_layer, [1, 2, 0, 3])
else:
mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)

assert not isinstance(
@@ -179,7 +194,10 @@ def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
"""
q = self.q_proj(query)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
if self.sequence_parallel:
q = tensor.transpose(x=q, perm=[1, 2, 0, 3])
else:
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

if isinstance(cache, self.StaticCache):
# for encoder-decoder attention in inference and has cached
@@ -211,9 +229,15 @@ def compute_kv(self, key, value):
k = self.k_proj(key)
v = self.v_proj(value)
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
if self.sequence_parallel:
k = tensor.transpose(x=k, perm=[1, 2, 0, 3])
else:
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
if self.sequence_parallel:
v = tensor.transpose(x=v, perm=[1, 2, 0, 3])
else:
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
return k, v

def gen_cache(self, key, value=None, type=Cache):
@@ -263,7 +287,12 @@ def core_attn(self, q, k, v, attn_mask=None):
out = tensor.matmul(weights, v)

# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
if self.sequence_parallel:
out = tensor.transpose(out, perm=[2, 0, 1, 3])
else:
out = tensor.transpose(out, perm=[0, 2, 1, 3])
# if sequence_parallel is true, out has shape [s, b, h] after the reshape;
# otherwise it has shape [b, s, h]
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

return out, weights
@@ -281,6 +310,10 @@ def forward(self,
"""
key = query if key is None else key
value = query if value is None else value
# if sequence_parallel is true, query, key and value have shape [s, b, h];
# otherwise their shape is [b, s, h]. n below denotes the mp degree.
# whether sequence_parallel is true or false, q, k and v end up with shape
# [b, num_heads/n, s, head_dim] after reshape and transpose.
# compute q, k, v
if use_cache is False:
if self.fuse:
@@ -302,6 +335,8 @@ def forward(self,
out, weights = self.core_attn(q, k, v, attn_mask=attn_mask)

# project to output
# if sequence_parallel is true, out has shape [s/n, b, h];
# otherwise it has shape [b, s, h], where n is the mp degree.
out = self.out_proj(out)

outs = [out]
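The extra transpose branches added throughout this class exist because the sequence-parallel layout puts the sequence dimension first. A small standalone check of the two permutations (illustrative only, runnable on a single device, with made-up sizes):

```python
import paddle

s, b, num_heads, head_dim = 8, 2, 4, 16
h = num_heads * head_dim

# sequence-parallel layout: [s, b, h]
x_sbh = paddle.randn([s, b, h])
q = paddle.reshape(x_sbh, [0, 0, num_heads, head_dim])   # [s, b, num_heads, head_dim]
q = paddle.transpose(q, perm=[1, 2, 0, 3])               # [b, num_heads, s, head_dim]

# default layout: [b, s, h]
x_bsh = paddle.randn([b, s, h])
q2 = paddle.reshape(x_bsh, [0, 0, num_heads, head_dim])  # [b, s, num_heads, head_dim]
q2 = paddle.transpose(q2, perm=[0, 2, 1, 3])             # [b, num_heads, s, head_dim]

# both layouts reach the [b, num_heads, s, head_dim] shape core_attn expects
assert q.shape == q2.shape == [b, num_heads, s, head_dim]
```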
@@ -323,16 +358,23 @@ def __init__(self,
norm=None,
hidden_size=None,
use_recompute=False,
recompute_granularity="full"):
recompute_granularity="full",
sequence_parallel=False):
super(TransformerDecoder, self).__init__()

self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.sequence_parallel = sequence_parallel
if norm == "LayerNorm":
self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
# if sequence parallel is true, register hooks that all-reduce
# the gradients of the LayerNorm weight and bias
if self.sequence_parallel:
self.norm.weight.register_hook(all_reduce_gradient_hook)
self.norm.bias.register_hook(all_reduce_gradient_hook)
elif norm is not None:
raise ValueError("Only support LayerNorm")
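`all_reduce_gradient_hook` also comes from `sequence_parallel_utils.py`, which this diff does not show. Conceptually it only needs to sum a replicated parameter's gradient over the model-parallel group, because under sequence parallel each rank computes the LayerNorm gradient from a different sequence shard. A minimal sketch of such a hook (the group lookup is an assumption):

```python
import paddle.distributed as dist
from paddle.distributed import fleet


def all_reduce_gradient_hook(grad):
    # sum the gradient of a replicated parameter (LayerNorm weight/bias)
    # across the model-parallel group
    hcg = fleet.get_hybrid_communicate_group()
    dist.all_reduce(grad, group=hcg.get_model_parallel_group())
    return grad


# registered exactly as in the diff:
# self.norm.weight.register_hook(all_reduce_gradient_hook)
# self.norm.bias.register_hook(all_reduce_gradient_hook)
```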

@@ -416,7 +458,8 @@ def __init__(self,
fused_linear=False,
recompute_attn=False,
use_recompute=False,
recompute_granularity="full"):
recompute_granularity="full",
sequence_parallel=False):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
@@ -427,6 +470,14 @@ def __init__(self,
self.normalize_before = normalize_before
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.sequence_parallel = sequence_parallel

if sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
@@ -440,17 +491,18 @@ def __init__(self,
num_partitions=num_partitions,
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)
recompute_granularity=recompute_granularity,
sequence_parallel=sequence_parallel)

self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
self.linear1 = ColumnParallelLinear(
d_model,
dim_feedforward,
weight_attr=weight_attrs[2],
gather_output=False,
has_bias=True,
fuse_matmul_bias=fused_linear)

self.linear2 = fleet.meta_parallel.RowParallelLinear(
self.linear2 = RowParallelLinear(
dim_feedforward,
d_model,
weight_attr=weight_attrs[2],
Expand All @@ -460,6 +512,12 @@ def __init__(self,

self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
if self.sequence_parallel:
# if sequence parallel is true, register hooks to all-reduce the gradients of the LayerNorm weights and biases
self.norm1.weight.register_hook(all_reduce_gradient_hook)
self.norm2.weight.register_hook(all_reduce_gradient_hook)
self.norm1.bias.register_hook(all_reduce_gradient_hook)
self.norm2.bias.register_hook(all_reduce_gradient_hook)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
@@ -524,9 +582,11 @@ def __init__(self,
hidden_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02):
initializer_range=0.02,
sequence_parallel=False):
super(GPTEmbeddings, self).__init__()

self.sequence_parallel = sequence_parallel
self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
vocab_size,
hidden_size,
@@ -550,7 +610,15 @@ def forward(self, input_ids, position_ids=None):
input_embedings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings
embeddings = self.dropout(embeddings)
# if sequence parallel is true, transpose the embeddings from [b, s, h] to [s, b, h];
# with the sequence dim first, the split along it is contiguous in memory
if self.sequence_parallel:
embeddings = paddle.transpose(embeddings, perm=[1, 0, 2])
embeddings = ScatterOp.apply(embeddings)
with get_rng_state_tracker().rng_state('local_seed'):
embeddings = self.dropout(embeddings)
else:
embeddings = self.dropout(embeddings)
return embeddings
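Under sequence parallel each mp rank keeps a different sequence shard, so the dropout mask must differ across ranks; the `local_seed` state of the RNG tracker provides exactly that, while the non-parallel branch keeps the default behaviour. A usage sketch (assumes an initialized hybrid-parallel job in which the `local_seed` state has been registered, and that `get_rng_state_tracker` is imported from `paddle.distributed.fleet.meta_parallel` as elsewhere in this file):

```python
import paddle
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

dropout = nn.Dropout(0.1, mode="upscale_in_train")
shard = paddle.randn([512, 8, 1024])   # a [s/n, b, h] shard of the embeddings

# draws from the per-rank seed, so every mp rank gets an independent mask
with get_rng_state_tracker().rng_state('local_seed'):
    shard = dropout(shard)
```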


@@ -569,7 +637,8 @@ def __init__(self,
num_partitions=1,
use_recompute=False,
fused_linear=False,
recompute_granularity="full"):
recompute_granularity="full",
sequence_parallel=False):

super(GPTModelHybrid, self).__init__()

@@ -579,7 +648,9 @@ def __init__(self,

self.embeddings = GPTEmbeddings(
vocab_size, hidden_size, hidden_dropout_prob,
max_position_embeddings, type_vocab_size, self.initializer_range)
max_position_embeddings, type_vocab_size, self.initializer_range,
sequence_parallel)
self.sequence_parallel = sequence_parallel

decoder_layers = nn.LayerList()
for i in range(num_layers):
@@ -599,15 +670,17 @@ def __init__(self,
num_partitions=num_partitions,
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))
recompute_granularity=recompute_granularity,
sequence_parallel=sequence_parallel))

self.decoder = TransformerDecoder(
decoder_layers,
num_layers,
norm="LayerNorm",
hidden_size=hidden_size,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)
recompute_granularity=recompute_granularity,
sequence_parallel=sequence_parallel)

def forward(self,
input_ids,
@@ -628,6 +701,8 @@ def forward(self,
# .expand_as(input_ids)
position_ids = paddle.fluid.layers.expand_as(position_ids,
input_ids)
# if sequence_parallel is true, embedding_output has shape [s/n, b, h];
# otherwise its shape is [b, s, h], where n is the mp degree
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids)

@@ -655,6 +730,10 @@ def forward(self,
use_cache=use_cache,
cache=cache)

if self.sequence_parallel:
encoder_outputs = GatherOp.apply(encoder_outputs)
encoder_outputs = paddle.transpose(encoder_outputs, [1, 0, 2])

return encoder_outputs
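At the end of the forward pass the sequence shards are gathered back and the tensor is returned to the usual [b, s, h] layout, so callers see the same output shape with or without sequence parallel. A single-process illustration of that bookkeeping, with `paddle.concat` standing in for the `GatherOp` collective:

```python
import paddle

n, s, b, h = 2, 8, 2, 16                                    # n = mp degree
shards = [paddle.randn([s // n, b, h]) for _ in range(n)]   # one [s/n, b, h] shard per rank
full = paddle.concat(shards, axis=0)                        # what GatherOp.apply produces
full = paddle.transpose(full, [1, 0, 2])                    # back to [b, s, h]
assert full.shape == [b, s, h]
```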


@@ -799,11 +878,15 @@ def __init__(self,
use_recompute=False,
fused_linear=False,
recompute_granularity="full",
virtual_pp_degree=1):
virtual_pp_degree=1,
sequence_parallel=False):

# forward desc
self.descs = []

assert sequence_parallel is False, "Sequence parallel strategy \
is not yet supported in the GPTForPretrainingPipe model."

self.descs.append(
SharedLayerDesc(
'embed',