Add tensor parallelism for RWKV #1237
base: main
@@ -0,0 +1,103 @@
{
  # Parallelism is not yet supported for rwkv
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  "num_layers": 24,
  "hidden_size": 2048,
  "num_attention_heads": 32, # head_size = dim_att / num_attention_heads.
                             # head_size is 64 for all rwkv models
  "seq_length": 4096,
  "max_position_embeddings": 4096,
  "output_layer_parallelism": "column",
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-5,
  "train_micro_batch_size_per_gpu": 4,

  "attention_config": [[["rwkv"], 24]],

  "activation": "silu",

  # model settings

  #"pos_emb": "rotary",
  "rotary_pct": 0.25,
  "no_weight_tying": true,
  "gpt_j_residual": true,

  # these should provide some speedup but takes a while to build, set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0008,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00008,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "data_impl": "mmap",
  "num_workers": 1,

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "bf16": {
    "bf16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1,
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "constant",
  "warmup": 0.01,
  "checkpoint_factor": 100,
  "eval_interval": 100000,
  "eval_iters": 10,
  "seed": 1234,

  # logging
  "log_interval": 10,
  "steps_per_print": 10,
  "wall_clock_breakdown": true,
}
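The head-size comment in this config is easy to verify: dim_att defaults to hidden_size when it is not set explicitly (see the model code later in the diff), so the arithmetic is just the following illustrative snippet, not part of the PR.

# Illustrative only: head_size = dim_att / num_attention_heads,
# with dim_att defaulting to hidden_size.
hidden_size, num_attention_heads = 2048, 32
dim_att = hidden_size
print(dim_att // num_attention_heads)  # -> 64, matching the "head_size is 64" note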
@@ -0,0 +1,103 @@
{
  # Parallelism is not yet supported for rwkv
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  "num_layers": 24,
  "hidden_size": 1024,
  "num_attention_heads": 16, # head_size = dim_att / num_attention_heads.
                             # head_size is 64 for all rwkv models
  "seq_length": 4096,
  "max_position_embeddings": 4096,
  "output_layer_parallelism": "column",
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-5,
  "train_micro_batch_size_per_gpu": 1,

  "attention_config": [[["rwkv"], 24]],

  "activation": "silu",

  # model settings

  #"pos_emb": "rotary",
  "rotary_pct": 0.25,
  "no_weight_tying": true,
  "gpt_j_residual": true,

  # these should provide some speedup but takes a while to build, set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0008,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00008,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "data_impl": "mmap",
  "num_workers": 1,

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "bf16": {
    "bf16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1,
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "constant",
  "warmup": 0.01,
  "checkpoint_factor": 100,
  "eval_interval": 100000,
  "eval_iters": 10,
  "seed": 1234,

  # logging
  "log_interval": 10,
  "steps_per_print": 10,
  "wall_clock_breakdown": true,
}
@@ -0,0 +1,102 @@
{
  # Parallelism is not yet supported for rwkv
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  "num_layers": 32,
  "hidden_size": 4096,
  "num_attention_heads": 64, # head_size = dim_att / num_attention_heads.
                             # head_size is 64 for all rwkv models
  "seq_length": 4096,
  "max_position_embeddings": 4096,
  "output_layer_parallelism": "column",
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-5,
  "train_micro_batch_size_per_gpu": 8,

  "attention_config": [[["rwkv"], 32]],

  "activation": "silu",

  # model settings

  #"pos_emb": "rotary",
  "rotary_pct": 0.25,
  "no_weight_tying": true,
  "gpt_j_residual": true,

  # these should provide some speedup but takes a while to build, set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0008,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00008,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "data_impl": "mmap",
  "num_workers": 1,

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "bf16": {
    "bf16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1,
  },

  # misc. training settings
  "train_iters": 500,
  "lr_decay_iters": 500,
  "distributed_backend": "nccl",
  "lr_decay_style": "constant",
  "warmup": 0.01,
  "checkpoint_factor": 100,
  "eval_interval": 100000,
  "eval_iters": 10,

  # logging
  "log_interval": 10,
  "steps_per_print": 10,
  "wall_clock_breakdown": true,
}
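Since the tensor-parallel code path in this PR divides both the channel dimension and the head count by the model-parallel world size, these configs are only usable with model_parallel_size values that divide both evenly. A minimal, hypothetical check (the helper name and candidate list are invented for illustration, not part of the PR):

# Hypothetical helper: list model-parallel sizes that divide both the hidden/channel
# dimension and the attention-head count evenly.
def valid_model_parallel_sizes(hidden_size, num_attention_heads, candidates=(1, 2, 4, 8)):
    return [mp for mp in candidates
            if hidden_size % mp == 0 and num_attention_heads % mp == 0]

print(valid_model_parallel_sizes(2048, 32))  # first config above
print(valid_model_parallel_sizes(1024, 16))  # second config
print(valid_model_parallel_sizes(4096, 64))  # third config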
@@ -7,7 +7,17 @@
 import torch.nn as nn
 from torch.nn import functional as F
 from torch.utils.cpp_extension import load

+from megatron import mpu
+from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region, scatter_to_model_parallel_region
+try:
+    from fla.ops.rwkv6 import chunk_rwkv6
+    import einops
+except ModuleNotFoundError:
+    print(
+        "Unable to import RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, \
+or directly from https://github.com/sustcsonglin/flash-linear-attention.git, or use CUDA kernels."
+    )
+    pass

[Review comment] This last point
[Review comment] reminder^

 class WKV(torch.autograd.Function):
     """
@@ -95,6 +105,18 @@ def backward(ctx, gy):
 def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u):
     return WKV.apply(B, T, C, H, r, k, v, w, u)

+@torch.compiler.disable(recursive=True)
+# torch.compiler introduces errors in numerical precision (torch 2.4)
+def RUN_FLA_CHUNK(B, T, C, H, r, k, v, w, u, h=None, scale=1.0, chunk_size=32):
+    r = r.view(B,T,H,-1).transpose(1,2)
+    k = k.view(B,T,H,-1).transpose(1,2)
+    v = v.view(B,T,H,-1).transpose(1,2)
+    # u can be 3d or 2d (B, H, -1) or just (H, -1) to save VRAM
+    w = -torch.exp(w.view(B,T,H,-1).transpose(1,2))
+    # change to scale=-1.0 when using fp16, this will apply scale to r and k.
+    o, final_state = chunk_rwkv6(r, k, v, w, u=u, scale=scale, initial_state=h,
+                                 output_final_state=False, chunk_size=chunk_size)  # initial_state=None and output_final_state=False for rwkv6
+    return o.transpose(1,2).reshape(B,T,C), final_state

 # RWKV6 time mix
 class RWKV_TimeMix(nn.Module):
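The reshapes in RUN_FLA_CHUNK above move between the flattened (B, T, C) layout used elsewhere in the layer and the per-head (B, H, T, head_size) layout handed to the chunked kernel. A standalone sketch of just that shape plumbing (illustrative only; chunk_rwkv6 itself is not called, the kernel output is stood in by the input):

import torch

B, T, H, head_size = 2, 8, 4, 64
C = H * head_size                              # per-rank channel count (C_tp later in this diff)
r = torch.randn(B, T, C)

r_heads = r.view(B, T, H, -1).transpose(1, 2)  # (B, H, T, head_size), as in RUN_FLA_CHUNK
assert r_heads.shape == (B, H, T, head_size)

o = r_heads                                    # stand-in for the kernel output
assert o.transpose(1, 2).reshape(B, T, C).shape == (B, T, C)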
@@ -104,7 +126,7 @@ class RWKV_TimeMix(nn.Module):
     TODO: fix jit compiling.
[Review comment] Is this based on the parser issue we discussed? I think it's worth testing just-jit and reordered jit and heuristics like I suggested before merging with this TODO.
     """

-    def __init__(self, neox_args, layer_number):
+    def __init__(self, neox_args, layer_number, init_method):
         super().__init__()
         self.neox_args = neox_args
         self.layer_number = layer_number
@@ -172,14 +194,46 @@ def __init__(self, neox_args, layer_number):
         )

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-        self.receptance = nn.Linear(
-            neox_args.hidden_size, neox_args.dim_att, bias=False
-        )
-        self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
-
-        self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
-        self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False)
-        self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
+        self.receptance = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.dim_att,
+            gather_output=False,
+            init_method=init_method,
+            bias=False,
+        )
+        self.key = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.dim_att,
+            gather_output=False,
+            init_method=init_method,
+            bias=False,
+        )
+        self.value = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.dim_att,
+            gather_output=False,
+            init_method=init_method,
+            bias=False,
+        )
+        self.output = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.dim_att,
+            output_size=neox_args.hidden_size,
+            gather_output=True,
+            init_method=init_method,
+            bias=False,
+        )
+        self.gate = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.dim_att,
+            gather_output=True,
+            init_method=init_method,
+            bias=False,
+        )
         self.ln_x = nn.GroupNorm(
             neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2)
         )
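For readers unfamiliar with the mpu layers used above: a ColumnParallelLinear with gather_output=False leaves each model-parallel rank holding a contiguous slice of the output features, which is why the RWKV kernel can later run on C/world_size channels per rank. A single-device sketch of that equivalence (assumed semantics; torch.chunk stands in for the real sharding, and none of this is gpt-neox code):

import torch

torch.manual_seed(0)
tokens, hidden_size, dim_att, world_size = 3, 8, 8, 2
x = torch.randn(tokens, hidden_size)
full = torch.nn.Linear(hidden_size, dim_att, bias=False)   # the unsharded projection

# Split the weight along the output (column) dimension, one shard per "rank".
shards = torch.chunk(full.weight, world_size, dim=0)       # nn.Linear weight is (out, in)
per_rank = [x @ w.t() for w in shards]                     # each rank computes its own slice

# Concatenating the per-rank slices reproduces the unsharded output.
assert torch.allclose(torch.cat(per_rank, dim=-1), full(x))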
@@ -200,13 +254,15 @@ def jit_func(self, x):
         xr = x + xx * (self.time_maa_r + mr)
         xg = x + xx * (self.time_maa_g + mg)

-        r = self.receptance(xr)
-        k = self.key(xk)
-        v = self.value(xv)
-        g = F.silu(self.gate(xg))
+        r, _ = self.receptance(xr)
+        k, _ = self.key(xk)
+        v, _ = self.value(xv)
+        gated, _ = self.gate(xg)
+        g = F.silu(gated)

         ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
         w = self.time_decay + ww
+        w = scatter_to_model_parallel_region(w)

         return r, k, v, g, w
@@ -215,28 +271,39 @@ def jit_func_2(self, x, g):
         x = x.view(B * T, C)

         x = self.ln_x(x).view(B, T, C)
-        x = self.output(x * g)
+        x, _ = self.output(x * g)

         return x

     def forward(self, x):
         B, T, C = x.size()
-        H = self.neox_args.num_attention_heads
+        C_tp = C//mpu.get_model_parallel_world_size()
+        H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size()

         r, k, v, g, w = self.jit_func(x)
-        x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa)
+        if self.neox_args.rwkv_fla:
+            x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H))
+        else:
+            x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H))
+
+        x = gather_from_model_parallel_region(x)

         return self.jit_func_2(x, g)


-class RWKV_ChannelMix(nn.Module):
+class ParallelRWKV_ChannelMix(nn.Module):
     """
     Channel Mix layer. The ffn in RWKV
     """

-    def __init__(self, neox_args, layer_number):
+    def __init__(self, neox_args, layer_number, init_method):
         super().__init__()
         self.neox_args = neox_args
         self.layer_number = layer_number

+        world_size = mpu.get_model_parallel_world_size()
+        self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
+
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

         with torch.no_grad():  # fancy init of time_mix
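In the forward pass above, time_faaaa is stored as a full (num_heads, head_size) parameter, flattened, scattered, and then viewed as (H, C_tp//H) on each rank. Because the flatten is row-major over heads, a contiguous split hands every rank whole heads. A small sketch of that bookkeeping (illustrative only; torch.chunk stands in for scatter_to_model_parallel_region, which splits along the last dimension):

import torch

num_heads, head_size, world_size = 8, 64, 2
time_faaaa = torch.arange(num_heads * head_size, dtype=torch.float32).view(num_heads, head_size)

H_local = num_heads // world_size
chunks = torch.chunk(time_faaaa.view(-1), world_size)    # stand-in for the scatter
for rank, chunk in enumerate(chunks):
    local = chunk.view(H_local, head_size)               # (H_local, head_size), as in forward()
    # each rank ends up with a contiguous block of complete heads
    assert torch.equal(local, time_faaaa[rank * H_local:(rank + 1) * H_local])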
@@ -247,38 +314,60 @@ def __init__(self, neox_args, layer_number):
         self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
         self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

-        self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False)
-        self.receptance = nn.Linear(
-            neox_args.hidden_size, neox_args.hidden_size, bias=False
-        )
-        self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False)
+        self.key = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.ffn_dim,
+            gather_output=False,
+            init_method=init_method,
+            bias=False,
+        )
+
+        self.receptance = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.hidden_size,
+            gather_output=True,
+            init_method=init_method,
+            bias=False
+        )
+        self.value = mpu.RowParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.ffn_dim,
+            output_size=neox_args.hidden_size,
+            input_is_parallel=True,
+            init_method=init_method,
+            parallel_output=False,
+            bias=False
+        )

     def forward(self, x):
         xx = self.time_shift(x) - x
         xk = x + xx * self.time_maa_k
         xr = x + xx * self.time_maa_r

-        k = self.key(xk)
+        k, _ = self.key(xk)
         k = torch.relu(k) ** 2
-        kv = self.value(k)
-        return torch.sigmoid(self.receptance(xr)) * kv
+        kv, _ = self.value(k)
+        receptance, _ = self.receptance(xr)
+        return torch.sigmoid(receptance) * kv


 class RWKVResidualLayer(nn.Module):
     """
     RWKV layer definition
     """

-    def __init__(self, neox_args, layer_number):
+    def __init__(self, neox_args, init_method, layer_number):
         super().__init__()
         self.neox_args = neox_args
         self.layer_number = layer_number
         self.fp16 = neox_args.precision == "fp16"
         self.bf16 = neox_args.precision == "bfloat16"
         assert (
             neox_args.intermediate_size == None or neox_args.expansion_factor == None
-        ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"
-        if not hasattr(neox_args, "dim_att"):
+        ), "Must pass either the absolute intermediate size or the relative expansion factor for rwkv"
+        if not neox_args.dim_att:
             neox_args.dim_att = neox_args.hidden_size
         if neox_args.intermediate_size:
             neox_args.ffn_dim = neox_args.intermediate_size
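The channel-mix sharding above follows the usual column-then-row parallel FFN pattern: key is column-parallel with gather_output=False, the squared ReLU is applied to each rank's local slice of ffn_dim, and the row-parallel value reduces the partial products. A single-device sketch of why that reproduces the unsharded FFN (assumed semantics; torch.chunk stands in for the mpu layers, not gpt-neox code):

import torch

torch.manual_seed(0)
tokens, hidden, ffn_dim, world_size = 3, 8, 16, 2
x = torch.randn(tokens, hidden)
key = torch.nn.Linear(hidden, ffn_dim, bias=False)
value = torch.nn.Linear(ffn_dim, hidden, bias=False)

reference = value(torch.relu(key(x)) ** 2)               # unsharded channel-mix core

k_shards = torch.chunk(key.weight, world_size, dim=0)    # column-parallel: split ffn_dim outputs
v_shards = torch.chunk(value.weight, world_size, dim=1)  # row-parallel: split ffn_dim inputs
partials = [(torch.relu(x @ kw.t()) ** 2) @ vw.t() for kw, vw in zip(k_shards, v_shards)]

# Summing the per-rank partial products reproduces the unsharded result.
assert torch.allclose(sum(partials), reference, atol=1e-6)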
@@ -297,43 +386,45 @@ def __init__(self, neox_args, layer_number):
         self.num_attention_heads = neox_args.num_attention_heads
         assert neox_args.dim_att % self.num_attention_heads == 0

+        self.init_method = init_method
         if neox_args.attention_dropout > 0:
             self.drop0 = nn.Dropout(p=neox_args.attention_dropout)

         self.ln1 = nn.LayerNorm(neox_args.hidden_size)
         self.ln2 = nn.LayerNorm(neox_args.hidden_size)

-        self.att = RWKV_TimeMix(neox_args, layer_number)
+        self.att = RWKV_TimeMix(neox_args, layer_number, init_method=init_method)

-        self.ffn = RWKV_ChannelMix(neox_args, layer_number)
+        self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method)

         if neox_args.attention_dropout > 0:
[Review comment] another attention arg for rwkv. Can we decouple attn dropout from rwkv?
             self.drop0 = nn.Dropout(p=neox_args.attention_dropout)
         if neox_args.hidden_dropout > 0:
             self.drop1 = nn.Dropout(p=neox_args.hidden_dropout)

         if layer_number == 0:
-            global wkv_cuda
-            """
-            Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not.
-            """
-            wkv_cuda = load(
-                name="wkv6",
-                sources=[
-                    "megatron/model/rwkv/v6/cuda/wkv6_op.cpp",
-                    f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",
-                ],
-                verbose=True,
-                extra_cuda_cflags=[
-                    "-res-usage",
-                    "--use_fast_math",
-                    "-O3",
-                    "-Xptxas -O3",
-                    "--extra-device-vectorization",
-                    f"-D_N_={self.neox_args.head_size}",
-                    f"-D_T_={self.neox_args.seq_length}",
-                ],
-            )
+            if not self.neox_args.rwkv_fla:
+                global wkv_cuda
+                """
+                Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not.
+                """
+                wkv_cuda = load(
+                    name="wkv6",
+                    sources=[
+                        "megatron/model/rwkv/v6/cuda/wkv6_op.cpp",
+                        f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",
+                    ],
+                    verbose=True,
+                    extra_cuda_cflags=[
+                        "-res-usage",
+                        "--use_fast_math",
+                        "-O3",
+                        "-Xptxas -O3",
+                        "--extra-device-vectorization",
+                        f"-D_N_={self.neox_args.head_size}",
+                        f"-D_T_={self.neox_args.seq_length}",
+                    ],
+                )

     def forward(self, x):
         neox_args = self.neox_args
@@ -353,7 +444,6 @@ def forward(self, x):

         return x

-

 class RWKVResidualLayerPipe(RWKVResidualLayer):
     """
     RWKV Pipeline Layer
@@ -363,4 +453,9 @@ def forward(self, args):
         assert len(args) == 2
         hidden_states, mask = args
         neox_args = self.neox_args
-        return super().forward(hidden_states), mask
+        if self.layer_number == 0:
+            hidden_states = hidden_states.transpose(0,1)
+        hidden_states = super().forward(hidden_states)
+        if self.layer_number == self.neox_args.num_layers-1:
+            hidden_states = hidden_states.transpose(0,1)
+        return hidden_states, mask
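The transposes in the pipeline wrapper above suggest (this is an inference from the diff, not something it states) that the pipeline passes activations as [seq_len, batch, hidden] while the RWKV mixing code works on [batch, seq_len, hidden]; flipping on the way into the first RWKV layer and back out of the last one restores the pipeline's layout. A trivial sketch under that assumption:

import torch

seq_len, batch, hidden = 16, 2, 8
pipeline_hidden = torch.randn(seq_len, batch, hidden)

rwkv_in = pipeline_hidden.transpose(0, 1)     # (batch, seq_len, hidden) for the RWKV layers
rwkv_out = rwkv_in                            # stand-in for the stack of RWKV layers
back_to_pipeline = rwkv_out.transpose(0, 1)   # (seq_len, batch, hidden) again

assert torch.equal(back_to_pipeline, pipeline_hidden)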
@@ -277,6 +277,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
     }
     """

+    rwkv_fla: bool = False
[Review comment] regen
+
+    """
+    Whether to use the Flash Linear Attention implementation of the RWKV kernel, or the CUDA kernel version.
+    """
+
     num_unique_layers: int = None
     """
     Number of unique transformer layers. num-layers should be divisible by this value. Currently only has an effect when pipe_parallel_size=0.
@@ -497,7 +502,6 @@ class NeoXArgsModel(NeoXArgsTemplate):

     # Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905)
     output_layer_parallelism: Literal["column"] = "column"
-
     """
     Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
     """
@@ -0,0 +1 @@
+git+https://github.com/sustcsonglin/flash-linear-attention
[Review comment] Similar comment here. Calling these attention heads is highly misleading.

[Author reply] I kind of disagree, as rwkv code generally references time mixing as attention, and the RWKV kernel is often called a type of "linear attention." I can add a bunch of config options to decouple rwkv and transformer settings, but in my opinion that would just create a lot of config args that serve essentially the same purpose.