forked from BlinkDL/RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it combines the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.

The RWKV Language Model (and my LM tricks)

RWKV: RNN with Transformer-level LLM Performance

RWKV is an RNN with Transformer-level LLM performance that can also be trained directly like a GPT transformer (parallelizable). It is 100% attention-free: you only need the hidden state at position t to compute the state at position t+1. You can use the "GPT" mode to quickly compute the hidden state for the "RNN" mode.

So it combines the best of RNN and transformer - great performance, fast inference, low VRAM use, fast training, "infinite" ctx_len, and free sentence embedding (using the final hidden state).

Download RWKV-4 0.1/0.4/1.5/3/7/14B weights: https://huggingface.co/BlinkDL

RWKV chatbot: https://github.com/BlinkDL/ChatRWKV

HF space: https://huggingface.co/spaces/yahma/rwkv-14b

[Figure: RWKV-chat]

You can run RWKV on low VRAM GPUs with this pip package: https://github.com/harrisonvanderbyl/rwkvstic

Join our Discord: https://discord.gg/bDSBUMeFpc :)

You are welcome to join the RWKV discord https://discord.gg/bDSBUMeFpc to build upon it. We have plenty of potential compute (A100 40Gs) now (thanks to Stability and EleutherAI), so if you have interesting ideas I can run them.

Twitter: https://twitter.com/BlinkDL_AI

[Figure: RWKV-eval2]

RWKV-3 1.5B on A40 (tf32) = always 0.015 sec/token, tested using simple pytorch code (no CUDA), GPU utilization 45%, VRAM 7823M

GPT2-XL 1.3B on A40 (tf32) = 0.032 sec/token (for ctxlen 1000), tested using HF, GPU utilization 45% too (interesting), VRAM 9655M

Training speed: RWKV-4 1.5B BF16 ctxlen1024 = 106K tokens/s on 8xA100 40G.

I am doing image experiments too (For example: https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder) and RWKV will be able to do txt2img diffusion :) My idea: 256x256 rgb image -> 32x32x13bit latents -> apply RWKV to compute transition probability for each of the 32x32 grid -> pretend the grids are independent and "diffuse" using these probabilities.

Smooth training - no loss spikes! (lr & bsz change around 15G tokens)

[Figure: RWKV-loss]

[Figure: RWKV-eval]

All of the trained models will be open-source. Inference is very fast (only matrix-vector multiplications, no matrix-matrix multiplications) even on CPUs, so you can even run an LLM on your phone.

How it works: RWKV gathers information to a number of channels, which are also decaying with different speeds as you move to the next token. It's very simple once you understand it.

RWKV is parallelizable because the time-decay of each channel is data-independent (and trainable). For example, in a usual RNN you can adjust the time-decay of a channel from, say, 0.8 to 0.5 (these are called "gates"), while in RWKV you simply move the information from a W-0.8-channel to a W-0.5-channel to achieve the same effect. Moreover, you can fine-tune RWKV into a non-parallelizable RNN (then you can use the outputs of later layers of the previous token) if you want extra performance.

[Figure: RWKV-formula]

Here are some of my TODOs. Let's work together :)

  • HuggingFace integration (check huggingface/transformers#17230), and optimized CPU & iOS & Android & WASM & WebGL inference. RWKV is an RNN and very friendly for edge devices. Let's make it possible to run an LLM on your phone.

  • Test it on bidirectional & MLM tasks, and image & audio & video tokens. I think RWKV can support Encoder-Decoder via this: for each decoder token, use a learned mixture of [decoder previous hidden state] & [encoder final hidden state]. Hence all decoder tokens will have access to the encoder output.

  • Now training RWKV-4a with one single tiny extra attention (just a few extra lines compared with RWKV-4) to further improve some difficult zeroshot tasks (such as LAMBADA) for smaller models. See https://github.com/BlinkDL/RWKV-LM/commit/a268cd2e40351ee31c30c5f8a5d1266d35b41829

User feedback:

I've so far toyed around the character-based model on our relatively small pre-training dataset (around 10GB of text), and the results are extremely good - similar ppl to models taking much, much longer to train.

dear god rwkv is fast. i switched to another tab after starting training it from scratch & when i returned it was emitting plausible english & maori words, i left to go microwave some coffee & when i came back it was producing fully grammatically correct sentences.

Tweet from Sepp Hochreiter (thank you!): https://twitter.com/HochreiterSepp/status/1524270961314484227

You can find me (BlinkDL) in the EleutherAI Discord too: https://www.eleuther.ai/get-involved/

[Figure: RWKV-demo]

Quick start

Use https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v4neo (latest code, compatible with v4).

Here is a great prompt for testing Q&A of LLMs. Works for any model: (found by minimizing ChatGPT ppls for RWKV 1.5B)

prompt = f'\nQ & A\n\nQuestion:\n{qq}\n\nDetailed Expert Answer:\n' # let the model generate after this

Cool Community RWKV Projects (check them!):

https://pypi.org/project/rwkvstic/ Easy pip package (with 8bit & offload for low VRAM GPUs)

https://github.com/harrisonvanderbyl/rwkv_chatbot Chatbot using rwkvstic

https://github.com/hizkifw/WebChatRWKVstic WebUI (WIP)

https://github.com/gururise/rwkv_gradio RWKV Gradio

https://github.com/mrsteyk/RWKV-LM-deepspeed Another training fork

https://github.com/Blealtan/RWKV-LM-LoRA LoRA fine-tuning

https://github.com/wozeparrot/tinyrwkv RWKV in tinygrad (nice simple DL framework)

huggingface/transformers#17230 RWKV HF package (WIP)

https://github.com/ArEnSc/Production-RWKV RWKV HF package source

https://github.com/nlpodyssey/verbaflow RWKV in Go

https://github.com/nlpodyssey/rwkv RWKV in Go

https://github.com/mrsteyk/rwkvk-rs RWKV in Rust

https://github.com/imxcstar/CSharp-RWKV-V4 RWKV in C#

https://github.com/resloved/RWKV-notebooks RWKV colab notebooks

https://colab.research.google.com/github/harrisonvanderbyl/rwkvstic/blob/master/notebooks/chatbot.ipynb RWKV chatbot colab notebook

https://github.com/Pathos14489/RWKVDistributedInference RWKV Distributed Inference

https://github.com/AXKuhta/rwkv-onnx-dml RWKV ONNX

https://github.com/josephrocca/rwkv-v4-web RWKV-v4 running in the browser (simple demo. greedy decode)

Inference

Run RWKV-4 Pile models: Download models from https://huggingface.co/BlinkDL. Set TOKEN_MODE = 'pile' in run.py and run it. It's fast even on CPU (the default mode).

Colab for RWKV-4 Pile 1.5B: https://colab.research.google.com/drive/1F7tZoPZaWJf1fsCmZ5tjw6sYHiFOYVWM

Run RWKV-4 Pile models in your browser (and onnx version): see this issue BlinkDL#7

RWKV-4 Web Demo: https://josephrocca.github.io/rwkv-v4-web/demo/ (note: only greedy sampling for now)

For the old RWKV-2: see the release here for a 27M params model on enwik8 with 0.72 BPC(dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN. You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).

I'd like to build an almost-INT8 version of RWKV. A simple method to quantize a matrix with outliers:

import numpy as np

# the original M, with outliers
M = np.array([[1,   2,   1,  2],[2,  100,    2, 10],[1,   2,   1, 2],[2,   1, 20, 1]])

# the scaled M, without outliers
Q = np.array([[1, 0.2, 0.1,  2],[0.4,  2, 0.04, 2], [1, 0.2, 0.1, 2],[2, 0.1,  2, 1]])
# we can find optimal a & b to minimize inference error after quantization
a = np.array([1, 10, 10, 1])
b = np.array([1, 5, 1, 1])

# test M.v with random v - the results will be the same
v = np.array([1.23, 5.44, 9.75, 2.98])
print(M.dot(v))
print(Q.dot(v * a) * b)

# even better: decompose M.dot(v) as Q.dot(v * a + aa) * b + bb where aa & bb are vectors too
# and can apply more scaling to achieve W8A8 (example: https://arxiv.org/pdf/2211.10438.pdf)

Training / Fine-tuning

Training RWKV-4 from scratch: run train.py, which by default is using the enwik8 dataset (unzip https://data.deepai.org/enwik8.zip).

You will be training the "GPT" version because it's parallelizable and faster to train. RWKV-4 can extrapolate, so training with ctxLen 1024 can work for ctxLen of 2500+. You can fine-tune the model with longer ctxLen and it can quickly adapt to longer ctxLens.

Fine-tuning RWKV-4 Pile models: use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into train.npy data. Then set EXPRESS_PILE_MODE to True in train.py, and run it.

Read the inference code in src/model.py and try using the final hidden state(.xx .aa .bb) as a faithful sentence embedding for other tasks. Probably you should begin with .xx and .aa/.bb (.aa divided by .bb).

Colab for fine-tuning RWKV-4 Pile models: https://colab.research.google.com/github/resloved/RWKV-notebooks/blob/master/RWKV_v4_RNN_Pile_Fine_Tuning.ipynb

Large corpus: Use https://github.com/EleutherAI/gpt-neox to convert .jsonl into .bin and .idx

python tools/preprocess_data.py --input ./my_data.jsonl --output-prefix ./data/my_data --vocab ./20B_tokenizer.json --dataset-impl mmap --tokenizer-type HFTokenizer --append-eod

The jsonl format sample (one line for each document):

{"meta": {"ID": 101}, "text": "This is the first document."}
{"meta": {"ID": 102}, "text": "Hello\nWorld"}
{"meta": {"ID": 103}, "text": "1+1=2\n1+2=3\n2+2=4"}

generated by code like this:

ss = json.dumps({"meta": meta, "text": text}, ensure_ascii=False)
out.write(ss + "\n")
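A slightly fuller sketch of the snippet above, assuming the documents are already available as (meta, text) pairs. The `write_jsonl` helper, the sample `docs` list, and the output path are hypothetical names used only for illustration.

```python
import json
import os
import tempfile

def write_jsonl(docs, path):
    """Write (meta, text) pairs as one JSON object per line."""
    with open(path, "w", encoding="utf-8") as out:
        for meta, text in docs:
            # json.dumps escapes the newlines inside "text",
            # so every document stays on a single line.
            ss = json.dumps({"meta": meta, "text": text}, ensure_ascii=False)
            out.write(ss + "\n")

# Hypothetical sample documents and output location.
docs = [
    ({"ID": 101}, "This is the first document."),
    ({"ID": 102}, "Hello\nWorld"),
]
path = os.path.join(tempfile.gettempdir(), "my_data.jsonl")
write_jsonl(docs, path)

# Read the file back to confirm that each line parses on its own.
with open(path, encoding="utf-8") as f:
    records = [json.loads(line) for line in f]
```

The resulting file can then be passed to the `--input` argument of the preprocessing command shown earlier.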

New ideas (just to record some new ideas)

I have an idea to improve tokenization. We can hardcode some channels to have meanings. Example:

Channel 0 = "space"

Channel 1 = "capitalize first letter"

Channel 2 = "capitalize all letters"

Therefore:

Embedding of "abc": [0, 0, 0, x0, x1, x2 , ..]

Embedding of " abc": [1, 0, 0, x0, x1, x2, ..]

Embedding of " Abc": [1, 1, 0, x0, x1, x2, ..]

Embedding of "ABC": [0, 0, 1, x0, x1, x2, ...]

......

so they will share most of the embedding. And we can rapidly compute the output probability of all variations of "abc".

Note: the above method is assuming that p(" xyz") / p("xyz") is the same for any "xyz", which can be wrong.

Better: define emb_space emb_capitalize_first emb_capitalize_all to be a function of emb.

Maybe the Best: let 'abc' ' abc' etc. to share the last 90% of their embeddings.

At this moment, all our tokenizers spend too many items to represent all variations of 'abc' ' abc' ' Abc' etc. Moreover the model cannot discover that these are actually similar if some of these variations are rare in the dataset. My method can solve this. I plan to test this in a new version of RWKV.
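A minimal sketch of the hardcoded-channel idea, in plain Python. Everything here is an assumption made for illustration: the helper names, the rule used to detect capitalization, the tiny embedding size, and the random initialization of the shared part.

```python
import random

FLAG_CHANNELS = 3  # channel 0 = space, 1 = capitalize first letter, 2 = capitalize all letters

def flags_and_base(token):
    """Split a token into its three flag channels and its lowercase base form."""
    space = 1.0 if token.startswith(" ") else 0.0
    word = token.lstrip(" ")
    cap_all = 1.0 if len(word) > 1 and word.isupper() else 0.0
    cap_first = 1.0 if not cap_all and word[:1].isupper() else 0.0
    return [space, cap_first, cap_all], word.lower()

# Hypothetical shared table: lowercase word -> the remaining (learned) channels.
base_table = {}

def embed(token, dim=8):
    """Flag channels followed by channels shared by all variations of the word."""
    flags, base = flags_and_base(token)
    if base not in base_table:
        rng = random.Random(base)  # stand-in for trained values
        base_table[base] = [rng.uniform(-1e-4, 1e-4) for _ in range(dim - FLAG_CHANNELS)]
    return flags + base_table[base]

e_plain = embed("abc")
e_space_cap = embed(" Abc")
e_upper = embed("ABC")
```

All three embeddings differ only in the first three channels, so the variations share everything the model learns about "abc".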

How it works

RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).

Moreover it's using a number of my tricks, such as:

The pseudocode (execution from top to bottom):

[Figure: RWKV-v2-RNN pseudocode]

The a b c d factors work together to build a time-decay curve: [X, 1, W, W^2, W^3, ...].

Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea:

  • a and b: EMAs of kv and k.
  • c and d: these are a and b combined with "self-attention".

kv / k is the memory mechanism. The token with high k can be remembered for a long duration, if W is close to 1 in the channel.

The R-gate is important for performance. k = info strength of this token (to be passed to future tokens). r = whether to apply the info to this token.

RWKV-3 improvements

Use different trainable TimeMix factors for R / K / V in SA and FF layers. Example:

xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

Use preLN instead of postLN (more stable & faster convergence):

if self.layer_id == 0:
	x = self.ln0(x)
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))

Explaining the code for RWKV-3 GPT mode

The GPT mode - overview

The building blocks of RWKV-3 GPT mode are similar to that of a usual preLN GPT.

The only difference is an extra LN after embedding. Note you can absorb this LN into the embedding after finishing the training.

x = self.emb(idx)  # input: idx = token indices
x = self.ln_emb(x) # extra LN after embedding
x = x + self.att_0(self.ln_att_0(x)) # preLN
x = x + self.ffn_0(self.ln_ffn_0(x))
...
x = x + self.att_n(self.ln_att_n(x))
x = x + self.ffn_n(self.ln_ffn_n(x))
x = self.ln_head(x) # final LN before projection
x = self.head(x)    # output: x = logits

It is important to initialize emb to tiny values, such as nn.init.uniform_(a=-1e-4, b=1e-4), to utilize my trick https://github.com/BlinkDL/SmallInitEmb.

For the 1.5B RWKV-3, I use Adam (no wd, no dropout) optimizer on 8 * A100 40G.

batchSz = 32 * 896, ctxLen = 896. I am using tf32 so the batchSz is a bit small.

For the first 15B tokens, LR is fixed at 3e-4, and beta=(0.9, 0.99).

Then I set beta=(0.9, 0.999), and do an exponential decay of LR, reaching 1e-5 at 332B tokens.
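The schedule described above can be sketched as a function of the number of tokens seen. This is only an illustration: it assumes the decay starts right after the first 15B tokens and that the LR is held at its final value after 332B tokens. The name `lr_at` and its parameters are hypothetical.

```python
import math

def lr_at(tokens, lr_init=3e-4, lr_final=1e-5, hold=15e9, end=332e9):
    """Constant LR for the first `hold` tokens, then exponential decay to lr_final at `end`."""
    if tokens <= hold:
        return lr_init
    progress = min((tokens - hold) / (end - hold), 1.0)
    # Exponential decay is a straight line in log-space.
    log_lr = math.log(lr_init) + progress * (math.log(lr_final) - math.log(lr_init))
    return math.exp(log_lr)

# LR at a few points of the run (tokens seen so far).
for tokens in (0, 15e9, 100e9, 200e9, 332e9):
    print(f"{tokens / 1e9:6.1f}B tokens -> lr = {lr_at(tokens):.3e}")
```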

The GPT mode - ATT block

The RWKV-3 does not have any attention in the usual sense, but we will call this block ATT anyway.

B, T, C = x.size() # x = (Batch,Time,Channel)

# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=60) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v

# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(x.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)

# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
if RUN_DEVICE == 'cuda':
    wkv = TimeX.apply(w, kv, B,C,T, 0)
    wk = TimeX.apply(w, k, B,C,T, K_EPS)
else:
    w = w[:,-T:].unsqueeze(1)
    wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
    wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS

# The RWKV formula
rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
rwkv = self.output(rwkv) # final output projection

The self.key, self.receptance, self.output matrices are all initialized to zero.

The time_mix, time_decay, time_first vectors are transferred from a smaller trained model (note: I sort & smooth them too).

The GPT mode - FFN block

The FFN block has three tricks compared with the usual GPT:

  1. My time_mix trick.

  2. The sqReLU from the Primer paper.

  3. An extra receptance-gate (similar to the receptance-gate in ATT block).

# Mix x with the previous timestep to produce xk, xr
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

# The usual FFN operation
k = self.key(xk)
k = torch.square(torch.relu(k)) # from the Primer paper
kv = self.value(k)

# Apply an extra receptance-gate to kv
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv

The self.value, self.receptance matrices are all initialized to zero.

RWKV-4 improvements

[Figure: RWKV-v3-plan]

From GPT to RWKV (the formulas)

Let F[t] be the system state at t.

Let x[t] be the new external input at t.

In GPT, predicting F[t+1] requires considering F[0], F[1], .. F[t]. So it takes O(T^2) to generate a length T sequence.

The simplified formula for GPT:

F[\mathrm{t}+1]=\frac{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) \cdot(\mathbf{V}F[\mathrm{i}])}{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}])}

It's very capable in theory, however that does not mean we can fully utilize its capability with usual optimizers. I suspect the loss landscape is too difficult for our current methods.

Compare with the simplified formula for RWKV (the parallel mode, looks similar to Apple's AFT):

F[\mathrm{t}+1]=\sigma(\mathbf{R}x[\mathrm{t}]) \cdot \frac{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) \cdot(\mathbf{V}F[\mathrm{i}])}{\sum_{\mathrm{i}=0}^{\mathrm{t}} \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K }F[\mathrm{i}])}

The R, K, V are trainable matrices, and W is a trainable vector (time-decay factor for each channel).

In GPT, the contribution of F[i] to F[t+1] is weighted by  \exp (\mathbf{Q}x[\mathrm{t}] * \mathbf{K}F[\mathrm{i}]) .

In RWKV-2, the contribution of F[i] to F[t+1] is weighted by \sigma(\mathbf{R}x[\mathrm{t}]) \cdot \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) \cdot \exp (\mathbf{K}F[\mathrm{i}]) .

  • The \sigma is a non-linearity and we can use sigmoid.
  • Note \sigma(\mathbf{R}x[\mathrm{t}]) is not in the denominator, and I call R the "receptance".
  • The \exp (\mathbf{W} \cdot(\mathrm{t}-\mathrm{i})) is the time-decay factor. I proposed the same idea (scaling the attention by distance) in Aug 2020 and called it the "time-weighting" (check the commit history of https://github.com/BlinkDL/minGPT-tuned).

Here comes the punchline: we can rewrite it into an RNN (recursive formula). Note:

F[1]=\sigma(\mathbf{R }x[0]) \cdot \frac{ \exp (\mathbf{K }F[0]) \cdot(\mathbf{V }F[0])}{\exp (\mathbf{K }F[0])}

F[2]=\sigma(\mathbf{R }x[1]) \cdot \frac{ \exp (\mathbf{K }F[1]) \cdot(\mathbf{V }F[1])+\exp (\mathbf{W} ) \cdot \exp (\mathbf{K }F[0]) \cdot(\mathbf{V }F[0])}{ \exp (\mathbf{K }F[1])+\exp (\mathbf{W} ) \cdot \exp (\mathbf{K }F[0])}

Therefore it's straightforward to verify:

F[t+1]=\sigma(\mathbf{R }x[t]) \cdot \frac{\exp (\mathbf{K}F[\mathrm{t}]) \cdot(\mathbf{V}F[\mathrm{t}])+\exp (\mathbf{W}) \cdot A[\mathrm{t}]}{ \exp (\mathbf{K}F[\mathrm{t}])+\exp (\mathbf{W}) \cdot B[\mathrm{t}]}

where A[t] and B[t] are the numerator and denominator of the previous step, respectively.
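To check the recursion numerically, here is a minimal single-channel sketch in plain Python. It treats r[t], k[i], v[i] as given numbers standing in for R x[t], K F[i], V F[i] (an assumption that keeps the example short; in the real model they are computed from the layer inputs), and out[t] plays the role of F[t+1]. The function names are hypothetical.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def rwkv_parallel(r, k, v, w):
    """Direct evaluation: for each t, sum over all i <= t (O(T^2) overall)."""
    out = []
    for t in range(len(k)):
        num = sum(math.exp(w * (t - i)) * math.exp(k[i]) * v[i] for i in range(t + 1))
        den = sum(math.exp(w * (t - i)) * math.exp(k[i]) for i in range(t + 1))
        out.append(sigmoid(r[t]) * num / den)
    return out

def rwkv_recurrent(r, k, v, w):
    """Same values via the recursion: only A and B are carried between steps (O(T))."""
    out, A, B = [], 0.0, 0.0
    for t in range(len(k)):
        A = math.exp(k[t]) * v[t] + math.exp(w) * A  # running numerator
        B = math.exp(k[t]) + math.exp(w) * B         # running denominator
        out.append(sigmoid(r[t]) * A / B)
    return out

rng = random.Random(0)
T = 16
r = [rng.uniform(-1, 1) for _ in range(T)]
k = [rng.uniform(-1, 1) for _ in range(T)]
v = [rng.uniform(-1, 1) for _ in range(T)]
w = -0.3  # negative, so exp(w * (t - i)) decays with distance

par = rwkv_parallel(r, k, v, w)
rec = rwkv_recurrent(r, k, v, w)
max_diff = max(abs(a - b) for a, b in zip(par, rec))
print(max_diff)  # only floating-point rounding separates the two
```

The recurrent version needs just two numbers of state per channel, which is why inference cost per token does not grow with the context length.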

I believe RWKV is performant because W is like repeatedly applying a diagonal matrix. Note (P^{-1} D P)^n = P^{-1} D^n P, so it is similar to repeatedly applying a general diagonalizable matrix.

Moreover it's possible to turn it into a continuous ODE (a bit similar to State Space Models). I will write about it later.

Multimodal ideas

I have an idea for [text --> 32x32 RGB image] using a LM (transformer, RWKV, etc.). Will test it soon.

Firstly, LM loss (instead of L2 loss), so the image will not be blurry.

Secondly, color quantization. For example, only allowing 8 levels for R/G/B. Then the image vocab size is 8x8x8 = 512 (for each pixel), instead of 2^24. Therefore, a 32x32 RGB image = a len1024 sequence of vocab512 (image tokens), which is a typical input for usual LMs. (Later we can use diffusion models to upsample and generate RGB888 images. We might be able to use a LM for this too.)
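A small sketch of the color quantization step, assuming uniform bins per channel and row-major pixel order (neither is specified above). The function names are hypothetical.

```python
LEVELS = 8  # levels per channel, so the per-pixel vocab is 8 * 8 * 8 = 512

def quantize_channel(value):
    """Map an 8-bit channel value (0..255) to a level in 0..LEVELS-1."""
    return value * LEVELS // 256

def pixel_to_token(r, g, b):
    """Pack three quantized levels into one token id in 0..511."""
    return (quantize_channel(r) * LEVELS + quantize_channel(g)) * LEVELS + quantize_channel(b)

def token_to_levels(token):
    """Inverse of the packing step: token id -> (r_level, g_level, b_level)."""
    return token // (LEVELS * LEVELS), (token // LEVELS) % LEVELS, token % LEVELS

def image_to_tokens(image):
    """Flatten a 32x32 image (rows of (r, g, b) tuples) into a length-1024 token sequence."""
    return [pixel_to_token(*px) for row in image for px in row]

# A hypothetical all-black 32x32 image.
image = [[(0, 0, 0)] * 32 for _ in range(32)]
tokens = image_to_tokens(image)
```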

Thirdly, 2D positional embeddings that are easy for the model to understand. For example, add one-hot X & Y coords to the first 64(=32+32) channels. Say if the pixel is at x=8, y=20, then we will add 1 to channel 8 and channel 52 (=32+20). Moreover probably we can add the float X & Y coords (normalized to 0~1 range) to another 2 channels. And other periodic pos. encoding might help too (will test).
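The one-hot coordinate scheme can be sketched as follows. The function name, the optional `float_offset` argument, and the choice to normalize by `size - 1` are assumptions for illustration.

```python
def add_position(embedding, x, y, size=32, float_offset=None):
    """Return a copy of `embedding` with the 2D position of a pixel added in."""
    out = list(embedding)
    out[x] += 1.0         # one-hot X in channels 0 .. size-1
    out[size + y] += 1.0  # one-hot Y in channels size .. 2*size-1
    if float_offset is not None:
        # Optional: float coordinates normalized to the 0~1 range in two more channels.
        out[float_offset] += x / (size - 1)
        out[float_offset + 1] += y / (size - 1)
    return out

emb = [0.0] * 128
with_pos = add_position(emb, x=8, y=20)                      # touches channels 8 and 52
with_float = add_position(emb, x=8, y=20, float_offset=64)   # also channels 64 and 65
```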

Finally, RandRound when doing the color quantization in the DataLoader. For example, if the float level is 4.578, then there is a 57.8% chance to use 5, and (1-57.8%) chance to use 4. And we can allow both 4 and 5 in the prediction, but the loss will be higher if the prediction is 4.
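RandRound is stochastic rounding. A minimal sketch, where the function name and the `rng` parameter are hypothetical:

```python
import math
import random

def rand_round(value, rng=random):
    """Round up with probability equal to the fractional part, otherwise round down."""
    low = math.floor(value)
    frac = value - low
    return low + (1 if rng.random() < frac else 0)

# With value = 4.578 the result is 5 about 57.8% of the time and 4 otherwise,
# so the average of many samples approaches 4.578.
rng = random.Random(0)
samples = [rand_round(4.578, rng) for _ in range(20000)]
mean = sum(samples) / len(samples)
```

Because the expected value equals the original float level, the quantization noise averages out over many epochs instead of biasing every pixel the same way.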

Multi-task training might help too. I will try this dataset format: [TxtFirst] [Desc of Img (txt tokens)] [Img] [img tokens] and sometimes [ImgFirst] [img tokens] [Txt] [Desc of Img (txt tokens)] ... the order of the imgs should be randomized in the DataLoader, and [TxtFirst] [ImgFirst] [Img] [Txt] are special tokens and do random sampling of the full dataset. So sometimes the model will see the img tokens first and then the corresponding txt tokens, which is a [img -> txt] task. And the model will see some partial imgs and partial txts. I think a char-level LM might help the model to write correct text on images.

How to sample a large dataset (for training)

I am using a trick to sample the Pile deterministically yet randomly enough.

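Let's say the Pile has x chunks (a chunk = ctx_len tokens).

Pick a prime number p just less than x, and make sure p = 2 (mod 3). Then gcd(3, p-1) = 1, so step -> step^3 (mod p) is a permutation of 0..p-1 and no chunk is repeated within p steps.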

Use (step * step * step) mod p to sample it. Add some bias to step for extra randomness.
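A minimal sketch of this sampling trick in plain Python. The helper names are hypothetical, and adding the bias to step before cubing is an assumption about how the bias is applied.

```python
def is_prime(n):
    """Trial division; fine for a one-off search."""
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True

def pick_prime(num_chunks):
    """Largest prime p < num_chunks with p % 3 == 2."""
    p = num_chunks - 1
    while p >= 2:
        if p % 3 == 2 and is_prime(p):
            return p
        p -= 1
    raise ValueError("num_chunks is too small")

def chunk_index(step, p, bias=0):
    """Chunk to read at training step `step`; `bias` shifts the sequence."""
    s = step + bias
    return (s * s * s) % p

p = pick_prime(1000)
visited = [chunk_index(step, p) for step in range(p)]
# Cubing is a bijection mod p, so p consecutive steps hit p distinct chunks.
all_distinct = len(set(visited)) == p
```

The order looks scrambled but is fully reproducible from the step counter, so a run can be resumed without storing any sampler state.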

The top-p-x sampling method (for inference)

We propose a new sampling method called top-p-x:

It is like top-p; the only difference is that you also keep all tokens whose prob > x.

Try x = 0.01 first.
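One way to read this rule, sketched in plain Python on a list of probabilities. The function name, the default `top_p`, and the renormalization at the end are assumptions; only the "also keep tokens whose prob > x" rule comes from the description above.

```python
def top_p_x_filter(probs, top_p=0.9, x=0.01):
    """Keep the top-p nucleus plus every token with prob > x, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        if cumulative < top_p:   # standard top-p: smallest prefix reaching top_p
            keep.add(i)
            cumulative += probs[i]
        elif probs[i] > x:       # the extra rule: also keep tokens above x
            keep.add(i)
        else:
            break                # sorted order: all remaining tokens are <= x too
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

probs = [0.5, 0.3, 0.1, 0.05, 0.03, 0.02]
# Plain top-p with top_p=0.7 would stop after the first two tokens;
# with x=0.04 the tokens at 0.1 and 0.05 survive as well.
filtered = top_p_x_filter(probs, top_p=0.7, x=0.04)
plain = top_p_x_filter(probs, top_p=0.7, x=1.0)  # x=1.0 disables the extra rule
```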

Better Learning Rate Schedule via Variational Method of Loss Curve

I propose a simple new method to find better LR schedules. The method is cost-efficient and practical for large LMs. The takeaway is that we can model the loss curve dynamics (phenomenology) w.r.t. the LR, and a nice closed-form LR curve can be directly computed from it using the variational method. Moreover we can predict the final loss with reasonable accuracy.

UPDATE: In "Conclusion 1.", use the best-fitting regime (ignore the initial steps where our approximations break down) to fit the parameters.

Try this: fixed lr for 1 hr, then exponential decay to 0.2 * lr in 12 hrs, and choose the t=[1hr, 13hr] segment.

In the last three plots, black = predicted loss curve of the new LR schedule, blue = original (unoptimized) real loss curve, orange = new LR schedule.

[Figure: better_lr_schedule]

RWKV v1

We propose the RWKV language model, with alternating time-mix and channel-mix layers:

\begin{align*}
\text{Time-mix :} && \text{TM}_{t,c} &&=&&\text{sigmoid}(\text{R}_{t,c}) &&\cdot&& &&\textstyle\sum_{u} &&\textbf{W}_{t,u,c} &&\cdot&& \text{softmax}_t(\text{K}_{u,c}) &&\cdot&& \text{V}_{u,c}\\
\text{Channel-mix :} && \text{CM}_{t,c} &&=&&\text{sigmoid}(\text{R}_{t,c}) &&\cdot&& &&\textstyle\sum_d &&\textbf{W}_{c,d} &&\cdot&& \text{gelu}(\text{K}_{t,d}) &&\cdot&& \text{V}_{t,d}
\end{align*}

  • The R, K, V are generated by linear transforms of the input, and W is a parameter. The idea of RWKV is to decompose attention into R(target) * W(src, target) * K(src). So we can call R "receptance", and sigmoid means it's in the 0~1 range.

  • The Time-mix is similar to AFT (https://arxiv.org/abs/2105.14103). There are two differences.

(1) We changed the normalization (denominator). For masked language models, we define:

\text{softmax}_t(\text{K}_{u,c}) = \frac{\exp(\text{K}_{u,c})}{\sum_{v \leq t}\exp(\text{K}_{v,c})}

(UPDATE: We are using the original AFT normalization in v2)

Initialize K and R matrices (and the output projection matrix) to ZERO for fast & stable convergence.

(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):

W_{t,u,c}=f_h(t-u)\cdot \alpha_h(u) \cdot \beta_h(t)

Moreover we multiply the final output of the Time-mix layer by γ(t). The α β γ factors are there because the context size is smaller when t is small, and they can compensate for this.

(UPDATE: We remove α β γ factors in v2-RNN and restrict W to be of a simple form and hence able to rewrite it as RNN)

Token-shift (time-shift mixing)

The token-shift explicitly uses (half the channels of this token) & (half the channels of prev token) to generate all vectors (QKV, RWKV, ...).

self.time_shift = nn.ZeroPad2d((0,0,1,-1))

x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)

Dividing channels by 2 and shift-1 works great for char-level English and char-level Chinese LM.

However for BPE-level English LM, it's only effective if your embedding is large enough (at least 1024 - so the usual small L12-D768 model is not enough).
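To make the shift concrete, here is a dependency-free sketch of what the two lines above do for a single sequence of shape (T, C), using nested lists instead of tensors. The function names are hypothetical; `nn.ZeroPad2d((0,0,1,-1))` pads one zero row at the start of the time axis and drops the last row.

```python
def time_shift(x):
    """x is a list of T rows (each a list of C channels); row t of the result is row t-1 of x."""
    C = len(x[0])
    return [[0.0] * C] + [list(row) for row in x[:-1]]

def token_shift(x):
    """First half of the channels comes from the previous token, second half from the current one."""
    C = len(x[0])
    shifted = time_shift(x)
    return [shifted[t][:C // 2] + list(x[t][C // 2:]) for t in range(len(x))]

# Three tokens with four channels each.
x = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
]
mixed = token_shift(x)
# mixed[0] = [0.0, 0.0, 3, 4]   (no previous token, so zeros)
# mixed[1] = [1, 2, 7, 8]
# mixed[2] = [5, 6, 11, 12]
```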

My theory on the effectiveness of token-shift:

When we train a GPT, the hidden representation of a token has to accomplish two different objectives:

  1. Predict the next token. Sometimes this is easy (obvious next token).

  2. Collect all previous context info, so later tokens can use it. This is always hard.

The shifted channels can focus on (2), so we have good propagation of info. It's like some kind of residual connection, or a small RNN inside the transformer.

You can use token-shift in usual QKV self-attention too. I looked at the weights, and found V really likes the shifted channels, less so for Q. Makes sense if you think about it. I also found you may want to use less mixing in higher layers.

p.s. There is a MHA_pro model in this repo with strong performance. Give it a try :)

The Head-QK Trick: learning to copy and avoid tokens

In a usual transformer, a small model has difficulty copying tokens (such as person names) in the context. We add extra Q & K to the final output such that the model can directly copy (or avoid) tokens in the context. Afterwards the model will teach itself NER (named entity recognition), as you can see if you look at the learned weights.

q = self.head_q(x)[:,:T,:] # projecting to 256-d
k = self.head_k(x)[:,:T,:] # projecting to 256-d
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()       
x = self.head(x) + c

Note: when a token occurs multiple times in the context, it might be better to use max(prob) instead of sum(prob).

The top-a sampling method

We also propose a new sampling method called top-a (as in src/utils.py):

(1) Find the max probability p_max after softmax.

(2) Remove all entries whose probability is lower than 0.2 * pow(p_max, 2). So it's adaptive, hence "top-a".

(3) Feel free to tune the 0.2 and 2 factor. Tune 0.2 first.

The idea of top-a:

  1. If max_prob=0.9, then remove all tokens with prob < 0.162 (so, removing all alternatives)
  2. If max_prob=0.5, then remove all tokens with prob < 0.05 (so, allowing more choices)
  3. If max_prob=0.1, then remove all tokens with prob < 0.002 (so, allowing lots of possibilities)

probs = F.softmax(logits, dim=-1)

# threshold = 0.2 * p_max^2, matching the factors described above
limit = torch.pow(torch.max(probs), 2) * 0.2
logits[probs < limit] = -float('Inf')

Performance

Character-level loss on simplebooks-92 dataset https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip

[Figure: RWKV-vs-MHA]

Gray: usual MHA+Rotary+GeGLU - performance not as good. 17.2M params.

Red: RWKV ("linear" attention) - VRAM friendly - considerably faster when the ctx window is long - good performance. 16.6M params.

Green: MHA+Rotary+GeGLU+Token_shift. 17.2M params.

Blue: MHA_pro (MHA with various tweaks & RWKV-type-FFN) - slow - needs more VRAM - good performance. 16.6M params.

@software{peng_bo_2021_5196578,
  author       = {PENG Bo},
  title        = {BlinkDL/RWKV-LM: 0.01},
  month        = aug,
  year         = 2021,
  publisher    = {Zenodo},
  version      = {0.01},
  doi          = {10.5281/zenodo.5196577},
  url          = {https://doi.org/10.5281/zenodo.5196577}
}

Initialization

We use careful initialization for RWKV to get fast convergence - orthogonal matrices with proper scaling, and special time_w curves. Check model.py for details.

Some learned time_w examples:

[Figure: RWKV-time-w]


Languages

  • Python 92.9%
  • Cuda 5.8%
  • C++ 1.3%