Add tensor parallelism for RWKV #1237

Open · wants to merge 30 commits into base: main

Commits (30)
4c7cb11
inital tp commits
jahatef Jun 4, 2024
46904d5
setup
jahatef Jun 19, 2024
e2933ef
configs
jahatef Sep 25, 2024
d1112ab
merge
jahatef Oct 3, 2024
43d641d
time mixing tp
jahatef Oct 3, 2024
de02f37
time-mixing
jahatef Oct 11, 2024
dd441b6
time mixing debugging
jahatef Oct 12, 2024
a418670
reset time_faaaa
jahatef Oct 13, 2024
540d856
Add additional asserts and update post training readme (#1300)
AI-WAIFU Oct 8, 2024
12aac35
Fix failling tests (#1301)
AI-WAIFU Oct 8, 2024
97c7915
inital tp commits
jahatef Jun 4, 2024
5f89ed8
merge
jahatef Nov 5, 2024
91cb759
Add ERROR logging prefix and sort the prefixes alphabetically (#1308)
TheBatmanofButler Oct 17, 2024
49b263a
inital tp commits
jahatef Jun 4, 2024
48de682
cleanup
jahatef Nov 6, 2024
c6fac96
cleanup
jahatef Nov 6, 2024
5a259c0
Update local_setup.yml
jahatef Nov 6, 2024
c2d6c85
add Triton FLA
jahatef Nov 10, 2024
bdb3658
change version of rwkv-fla
jahatef Nov 12, 2024
ff7f328
fix a GQA issue (#1314) (#1315)
tiandeyu-cs Nov 13, 2024
1350b2c
fix 'intermediate_size' in Llama configuration files after the 'mlp_t…
tiandeyu-cs Nov 13, 2024
c4d7a54
Python 3.10 support (#1313)
markNZed Nov 13, 2024
ee2f142
Fix documentation for converting SFT/DPO weights back to HF Llama (#1…
jacobthebanana Nov 13, 2024
6e81f0b
fix bug (#1311)
AI-WAIFU Nov 13, 2024
df95419
Add support for dropout in sparse attention (#1312)
michaelc-yu Nov 16, 2024
d682529
adds pyproject files and tests (#1302)
LouisCastricato Nov 16, 2024
0bc11d6
undo merge error (#1325)
Quentin-Anthony Nov 27, 2024
c6db95c
inital tp commits
jahatef Jun 4, 2024
daac503
setup
jahatef Jun 19, 2024
bf478ce
Merge branch 'main' into rwkv-tp
Quentin-Anthony Dec 19, 2024
4 changes: 4 additions & 0 deletions configs/local_setup.yml
@@ -22,6 +22,10 @@
"load": "checkpoints",
"checkpoint_validation_with_forward_pass": False,


# "launcher": "openmpi",
#"deepspeed_mpi": true,

"tensorboard_dir": "tensorboard",
"log_dir": "logs",
}
103 changes: 103 additions & 0 deletions configs/rwkv/1.5B.yml
@@ -0,0 +1,103 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 1,

"num_layers": 24,
"hidden_size": 2048,
"num_attention_heads": 32, # head_size = dim_att / num_attention_heads.
# head_size is 64 for all rwkv models
"seq_length": 4096,
"max_position_embeddings": 4096,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 4,

"attention_config": [[["rwkv"], 24]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,
"seed": 1234,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
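A quick sanity check on the head-size comment above, as a minimal Python sketch. The hidden sizes and head counts are taken from this 1.5B config and the 430M/7B configs below; the fallback of dim_att to hidden_size comes from the rwkv.py change later in this PR.

# Verify head_size = dim_att / num_attention_heads == 64 for the RWKV configs in this PR.
configs = {
    "1.5B": {"hidden_size": 2048, "num_attention_heads": 32},
    "430M": {"hidden_size": 1024, "num_attention_heads": 16},
    "7B": {"hidden_size": 4096, "num_attention_heads": 64},
}
for name, cfg in configs.items():
    dim_att = cfg["hidden_size"]  # dim_att defaults to hidden_size when not set
    head_size = dim_att // cfg["num_attention_heads"]
    assert head_size == 64, name
    print(name, "head_size =", head_size)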
103 changes: 103 additions & 0 deletions configs/rwkv/430M.yml
@@ -0,0 +1,103 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 1,

"num_layers": 24,
"hidden_size": 1024,
"num_attention_heads": 16, # head_size = dim_att / num_attention_heads.
Review comment (Member): Similar comment here. Calling these attention heads is highly misleading.

Reply (Collaborator, PR author): I kind of disagree, as RWKV code generally references time mixing as attention, and the RWKV kernel is often described as a form of "linear attention." I could add a set of config options to decouple RWKV from the transformer settings, but in my opinion that would just create a lot of config args that serve essentially the same purpose.

# head_size is 64 for all rwkv models
"seq_length": 4096,
"max_position_embeddings": 4096,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 1,

"attention_config": [[["rwkv"], 24]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,
"seed": 1234,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
102 changes: 102 additions & 0 deletions configs/rwkv/7B.yml
@@ -0,0 +1,102 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 1,

"num_layers": 32,
"hidden_size": 4096,
"num_attention_heads": 64, # head_size = dim_att / num_attention_heads.
# head_size is 64 for all rwkv models
"seq_length": 4096,
"max_position_embeddings": 4096,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 8,

"attention_config": [[["rwkv"], 32]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 500,
"lr_decay_iters": 500,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
1 change: 1 addition & 0 deletions megatron/model/gpt2_model.py
@@ -258,6 +258,7 @@ def init_specs(self):
LayerSpec(
RWKVResidualLayerPipe,
neox_args=self.neox_args,
init_method=self.init_method,
layer_number=i,
)
)
205 changes: 150 additions & 55 deletions megatron/model/rwkv/v6/rwkv.py
@@ -7,7 +7,17 @@
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.cpp_extension import load

from megatron import mpu
from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region, scatter_to_model_parallel_region
try:
from fla.ops.rwkv6 import chunk_rwkv6
import einops
except ModuleNotFoundError:
print(
"Unable to import RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, \
or directly from https://github.com/sustcsonglin/flash-linear-attention.git, or use CUDA kernels."
Review comment (Member): This last point, "or use CUDA kernels", is confusing. Can you add a "by doing xyz" so that users know what you mean?

Review comment (Member): reminder^
)
pass

class WKV(torch.autograd.Function):
"""
@@ -95,6 +105,18 @@ def backward(ctx, gy):
def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u):
return WKV.apply(B, T, C, H, r, k, v, w, u)

@torch.compiler.disable(recursive=True)
# torch.compiler introduces errors in numerical precision (torch 2.4)
def RUN_FLA_CHUNK(B, T, C, H, r, k, v, w, u, h=None, scale=1.0, chunk_size=32):
r = r.view(B,T,H,-1).transpose(1,2)
k = k.view(B,T,H,-1).transpose(1,2)
v = v.view(B,T,H,-1).transpose(1,2)
# u can be 3d or 2d (B, H, -1) or just (H, -1) to save VRAM
w = -torch.exp(w.view(B,T,H,-1).transpose(1,2))
# change to scale=-1.0 when using fp16, this will apply scale to r and k.
o, final_state = chunk_rwkv6(r, k, v, w, u=u, scale=scale, initial_state=h,
output_final_state=False, chunk_size=chunk_size) #initial_state=None and output_final_state=False for rwkv6
return o.transpose(1,2).reshape(B,T,C), final_state
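A shape-only sketch of the reshapes RUN_FLA_CHUNK performs around the FLA kernel call (hypothetical sizes; this does not invoke chunk_rwkv6 itself):

import torch

# Hypothetical per-rank sizes: batch 2, sequence 8, 128 channels, 2 heads.
B, T, C, H = 2, 8, 128, 2
r = torch.randn(B, T, C)

# (B, T, C) -> (B, T, H, C/H) -> (B, H, T, C/H), as done for r/k/v/w above.
r_heads = r.view(B, T, H, -1).transpose(1, 2)
assert r_heads.shape == (B, H, T, C // H)

# The kernel output is mapped back to (B, T, C) the same way in reverse.
o = r_heads.transpose(1, 2).reshape(B, T, C)
assert o.shape == (B, T, C)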

# RWKV6 time mix
class RWKV_TimeMix(nn.Module):
@@ -104,7 +126,7 @@ class RWKV_TimeMix(nn.Module):
TODO: fix jit compiling.
Review comment (Member): Is this based on the parser issue we discussed? I think it's worth testing just-jit and reordered jit and heuristics like I suggested before merging with this TODO.

"""

def __init__(self, neox_args, layer_number):
def __init__(self, neox_args, layer_number, init_method):
super().__init__()
self.neox_args = neox_args
self.layer_number = layer_number
@@ -172,14 +194,46 @@ def __init__(self, neox_args, layer_number):
)

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(
neox_args.hidden_size, neox_args.dim_att, bias=False
)
self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)

self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False)
self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
self.receptance = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.dim_att,
gather_output=False,
init_method=init_method,
bias=False,
)
self.key = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.dim_att,
gather_output=False,
init_method=init_method,
bias=False,
)
self.value = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.dim_att,
gather_output=False,
init_method=init_method,
bias=False,
)
self.output = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.dim_att,
output_size=neox_args.hidden_size,
gather_output=True,
init_method=init_method,
bias=False,
)
self.gate = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.dim_att,
gather_output=True,
init_method=init_method,
bias=False,
)
self.ln_x = nn.GroupNorm(
neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2)
)
@@ -200,13 +254,15 @@ def jit_func(self, x):
xr = x + xx * (self.time_maa_r + mr)
xg = x + xx * (self.time_maa_g + mg)

r = self.receptance(xr)
k = self.key(xk)
v = self.value(xv)
g = F.silu(self.gate(xg))
r, _ = self.receptance(xr)
k, _ = self.key(xk)
v, _ = self.value(xv)
gated, _ = self.gate(xg)
g = F.silu(gated)

ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
w = self.time_decay + ww
w = scatter_to_model_parallel_region(w)

return r, k, v, g, w
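The scatter on w keeps only this rank's slice of the decay so it lines up with the column-parallel (non-gathered) slices of r, k and v. A single-process sketch of the equivalent slicing, assuming the last-dimension split used by the Megatron-style mpu helpers:

import torch

# Single-process illustration of the assumed last-dim split.
world_size, rank = 2, 0
B, T, dim_att = 1, 8, 256
w = torch.randn(B, T, dim_att)  # full decay, as computed in jit_func

per_rank = dim_att // world_size
w_local = torch.chunk(w, world_size, dim=-1)[rank]  # slice this rank keeps
assert w_local.shape == (B, T, per_rank)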

@@ -215,28 +271,39 @@ def jit_func_2(self, x, g):
x = x.view(B * T, C)

x = self.ln_x(x).view(B, T, C)
x = self.output(x * g)
x, _ = self.output(x * g)

return x

def forward(self, x):
B, T, C = x.size()
H = self.neox_args.num_attention_heads
C_tp = C//mpu.get_model_parallel_world_size()
H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size()

r, k, v, g, w = self.jit_func(x)
x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa)
if self.neox_args.rwkv_fla:
x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H))
else:
x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H))

x = gather_from_model_parallel_region(x)

return self.jit_func_2(x, g)
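A worked example of the per-rank sizes computed in forward above, using the 1.5B config from this PR (model-parallel size 2 is an assumed setting for illustration):

hidden_size = 2048            # C
num_attention_heads = 32
mp_world_size = 2             # assumed model_parallel_size

C_tp = hidden_size // mp_world_size        # 1024 channels per rank
H = num_attention_heads // mp_world_size   # 16 heads per rank
head_size = C_tp // H                      # still 64, unchanged by tensor parallelism
print(C_tp, H, head_size)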


class RWKV_ChannelMix(nn.Module):
class ParallelRWKV_ChannelMix(nn.Module):
"""
Channel Mix layer. The ffn in RWKV
"""

def __init__(self, neox_args, layer_number):
def __init__(self, neox_args, layer_number, init_method):
super().__init__()
self.neox_args = neox_args
self.layer_number = layer_number

world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

with torch.no_grad(): # fancy init of time_mix
@@ -247,38 +314,60 @@ def __init__(self, neox_args, layer_number):
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False)
self.receptance = nn.Linear(
neox_args.hidden_size, neox_args.hidden_size, bias=False
)
self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False)
self.key = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.ffn_dim,
gather_output=False,
init_method=init_method,
bias=False,
)

self.receptance = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.hidden_size,
gather_output=True,
init_method=init_method,
bias=False
)
self.value = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=neox_args.ffn_dim,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=init_method,
parallel_output=False,
bias=False
)

def forward(self, x):
xx = self.time_shift(x) - x
xk = x + xx * self.time_maa_k
xr = x + xx * self.time_maa_r

k = self.key(xk)
k, _ = self.key(xk)
k = torch.relu(k) ** 2
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv
kv, _ = self.value(k)
receptance, _ = self.receptance(xr)
return torch.sigmoid(receptance) * kv
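The key/value pairing above follows the usual Megatron MLP pattern: key splits its output dimension (gather_output=False), value splits its input dimension (input_is_parallel=True), the element-wise relu()**2 acts on each shard independently, and the row-parallel value only needs a final sum (all-reduce) across ranks. A single-process sketch of why that decomposition is exact:

import torch

torch.manual_seed(0)
hidden, ffn, world = 8, 16, 2
x = torch.randn(3, hidden)
W_key = torch.randn(ffn, hidden)   # key: hidden -> ffn
W_val = torch.randn(hidden, ffn)   # value: ffn -> hidden

# Unpartitioned reference.
ref = torch.relu(x @ W_key.t()) ** 2 @ W_val.t()

# Column-parallel key (output dim split) + row-parallel value (input dim split).
partial = torch.zeros_like(ref)
for Wk_shard, Wv_shard in zip(W_key.chunk(world, dim=0), W_val.chunk(world, dim=1)):
    k_local = torch.relu(x @ Wk_shard.t()) ** 2   # stays local to each rank
    partial += k_local @ Wv_shard.t()             # the sum is an all-reduce in real TP

assert torch.allclose(ref, partial, rtol=1e-4, atol=1e-4)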


class RWKVResidualLayer(nn.Module):
"""
RWKV layer definition
"""

def __init__(self, neox_args, layer_number):
def __init__(self, neox_args, init_method, layer_number):
super().__init__()
self.neox_args = neox_args
self.layer_number = layer_number
self.fp16 = neox_args.precision == "fp16"
self.bf16 = neox_args.precision == "bfloat16"
assert (
neox_args.intermediate_size == None or neox_args.expansion_factor == None
), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"
if not hasattr(neox_args, "dim_att"):
), "Must pass either the absolute intermediate size or the relative expansion factor for rwkv"
if not neox_args.dim_att:
neox_args.dim_att = neox_args.hidden_size
if neox_args.intermediate_size:
neox_args.ffn_dim = neox_args.intermediate_size
@@ -297,43 +386,45 @@ def __init__(self, neox_args, layer_number):
self.num_attention_heads = neox_args.num_attention_heads
assert neox_args.dim_att % self.num_attention_heads == 0

self.init_method = init_method
if neox_args.attention_dropout > 0:
self.drop0 = nn.Dropout(p=neox_args.attention_dropout)

self.ln1 = nn.LayerNorm(neox_args.hidden_size)
self.ln2 = nn.LayerNorm(neox_args.hidden_size)

self.att = RWKV_TimeMix(neox_args, layer_number)
self.att = RWKV_TimeMix(neox_args, layer_number, init_method=init_method)

self.ffn = RWKV_ChannelMix(neox_args, layer_number)
self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method)

if neox_args.attention_dropout > 0:
Review comment (Member): Another attention arg for rwkv. Can we decouple attn dropout from rwkv?

self.drop0 = nn.Dropout(p=neox_args.attention_dropout)
if neox_args.hidden_dropout > 0:
self.drop1 = nn.Dropout(p=neox_args.hidden_dropout)

if layer_number == 0:
global wkv_cuda
"""
Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not.
"""
wkv_cuda = load(
name="wkv6",
sources=[
"megatron/model/rwkv/v6/cuda/wkv6_op.cpp",
f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",
],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={self.neox_args.head_size}",
f"-D_T_={self.neox_args.seq_length}",
],
)
if not self.neox_args.rwkv_fla:
global wkv_cuda
"""
Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not.
"""
wkv_cuda = load(
name="wkv6",
sources=[
"megatron/model/rwkv/v6/cuda/wkv6_op.cpp",
f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",
],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={self.neox_args.head_size}",
f"-D_T_={self.neox_args.seq_length}",
],
)

def forward(self, x):
neox_args = self.neox_args
@@ -353,7 +444,6 @@ def forward(self, x):

return x


class RWKVResidualLayerPipe(RWKVResidualLayer):
"""
RWKV Pipeline Layer
@@ -363,4 +453,9 @@ def forward(self, args):
assert len(args) == 2
hidden_states, mask = args
neox_args = self.neox_args
return super().forward(hidden_states), mask
if self.layer_number == 0:
hidden_states = hidden_states.transpose(0,1)
hidden_states = super().forward(hidden_states)
if self.layer_number == self.neox_args.num_layers-1:
hidden_states = hidden_states.transpose(0,1)
return hidden_states, mask
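The transposes at the first and last RWKV layers appear to bridge the sequence-first activation layout passed between pipeline stages and the batch-first layout the RWKV kernels operate on; a minimal shape sketch (the layouts are assumptions, not stated in this diff):

import torch

seq_len, batch, hidden = 16, 4, 64
pipe_acts = torch.randn(seq_len, batch, hidden)  # assumed pipeline layout [s, b, h]

x = pipe_acts.transpose(0, 1)   # layer 0: -> [b, s, h] for the RWKV layers
# ... RWKV residual layers run on [b, s, h] ...
out = x.transpose(0, 1)         # last layer: back to [s, b, h] for the next stage
assert out.shape == pipe_acts.shape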
8 changes: 1 addition & 7 deletions megatron/neox_arguments/arguments.py
Original file line number Diff line number Diff line change
@@ -1113,17 +1113,11 @@ def calculate_derived(self):
if isinstance(self.zero_stage, int):
assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba"
assert (
self.hidden_dropout == 0.0,
self.hidden_dropout != 0.0,
), "Mamba does not yet have dropout implemented"
if "rwkv" in self.attention_config:
assert (
self.model_parallel_size == 1
), "RWKV not currently compatible with model parallelism"
if isinstance(self.zero_stage, int):
assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV"
assert (
self.hidden_dropout == 0.0,
), "RWKV does not yet have dropout implemented"

# Sparsity config
if self.sparsity_config is None:
6 changes: 5 additions & 1 deletion megatron/neox_arguments/neox_args.py
@@ -277,6 +277,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
}
"""

rwkv_fla: bool = False
Review comment (Member): Regen neox_arguments.md, since this isn't showing up there.

"""
Whether to use the Flash Linear Attention implementation of the RWKV kernel, or the CUDA kernel version.
"""

num_unique_layers: int = None
"""
Number of unique transformer layers. num-layers should be divisible by this value. Currently only has an effect when pipe_parallel_size=0.
@@ -497,7 +502,6 @@ class NeoXArgsModel(NeoXArgsTemplate):

# Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905)
output_layer_parallelism: Literal["column"] = "column"

"""
Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
"""
1 change: 1 addition & 0 deletions requirements/requirements-rwkv.txt
@@ -0,0 +1 @@
git+https://github.com/sustcsonglin/flash-linear-attention
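A quick way to confirm the optional dependency resolved after installing from this file: a minimal sketch mirroring the import guard added in megatron/model/rwkv/v6/rwkv.py above.

# Check that the FLA kernels are importable, the same way rwkv.py does.
try:
    from fla.ops.rwkv6 import chunk_rwkv6  # noqa: F401
    import einops  # noqa: F401
    print("RWKV FLA kernels available")
except ModuleNotFoundError:
    print("FLA not installed; the CUDA kernel path (rwkv_fla: false) will be needed")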