Fix time checkpointing.
The previous implementation skips the backward pass for every chunk except
the last. Fixing this breaks deepspeed stage 2, though.

Also make the states actually checkpointed by expanding the state classes
into raw tensors.
Blealtan committed Jun 28, 2023
1 parent 1d60989 commit 6792641
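
As a minimal sketch of the pattern this commit moves to (illustrative only, not code from this repository; torch.utils.checkpoint stands in for deepspeed.checkpointing.checkpoint and `step` is a placeholder for one ctx_len chunk): carrying the recurrent state across chunks as raw tensors is what lets the backward pass reach every checkpointed chunk instead of only the last one.

import torch
from torch.utils.checkpoint import checkpoint

def step(x, shift_state, wkv_state):
    out = x * 2 + shift_state                     # stand-in for one chunk's forward pass
    return out.sum(), shift_state + out.mean(), wkv_state + 1

x_chunks = torch.randn(4, 8, requires_grad=True)  # 4 chunks of a longer sequence
shift = torch.zeros(8)
wkv = torch.zeros(8)

loss = torch.tensor(0.0)
for i in range(x_chunks.shape[0]):
    if i < x_chunks.shape[0] - 1:
        # every chunk except the last goes through activation checkpointing;
        # passing the states as plain tensors keeps them in the autograd graph
        chunk_loss, shift, wkv = checkpoint(step, x_chunks[i], shift, wkv,
                                            use_reentrant=False)
    else:
        chunk_loss, shift, wkv = step(x_chunks[i], shift, wkv)
    loss = loss + chunk_loss

loss.backward()
print(x_chunks.grad.abs().sum(dim=1))             # non-zero for every chunk, not just the last
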
Showing 2 changed files with 101 additions and 57 deletions.
3 changes: 2 additions & 1 deletion RWKV-v4neo/config-example.yaml
@@ -24,7 +24,8 @@ trainer:
# For more details see:
# https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed-zero-stage-2
#
strategy: deepspeed_stage_2_offload
#!FIXME: currently only deepspeed_stage_1 is supported, because deepspeed cannot handle repeated backward hooks.
strategy: deepspeed_stage_1

# Floating point precision for the model, because RWKV is built FOR bf16
# you should pretty much never change this setting
155 changes: 99 additions & 56 deletions RWKV-v4neo/src/model.py
@@ -2,15 +2,17 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, math
import gc, math
from random import randint
from typing import List, Optional

import numpy as np
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F

import lightning as L
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
from lightning.pytorch.strategies import DeepSpeedStrategy
@@ -32,32 +34,56 @@

class TimeMixState:

def __init__(self, token_shift_state: torch.Tensor,
wkv_state: torch.Tensor):
self.token_shift_state = token_shift_state
def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
self.shift_state = shift_state
self.wkv_state = wkv_state


class ChannelMixState:

def __init__(self, token_shift_state: torch.Tensor):
self.token_shift_state = token_shift_state
def __init__(self, shift_state: torch.Tensor):
self.shift_state = shift_state


class BlockState:

def __init__(self, time_mix_state: torch.Tensor,
channel_mix_state: torch.Tensor):
def __init__(self, time_mix_state: TimeMixState,
channel_mix_state: ChannelMixState):
self.time_mix_state = time_mix_state
self.channel_mix_state = channel_mix_state


def init_block_state(B, C, device, dtype):
wkv_state = torch.zeros((B, C, 3), device=device, dtype=torch.float)
wkv_state[:, :, -1] = -1e38
token_shift_state = torch.zeros((B, C), device=device, dtype=dtype)
return BlockState(TimeMixState(token_shift_state, wkv_state),
ChannelMixState(token_shift_state))
class BlockStateList:

def __init__(self, shift_states, wkv_states):
self.wkv_states = wkv_states
self.shift_states = shift_states

@staticmethod
def create(N, B, C, device, dtype):
result = BlockStateList.empty(N, B, C, device, dtype)
result.wkv_states[:] = 0
result.wkv_states[:, :, :, -1] = -1e38
result.shift_states[:] = 0
return result

@staticmethod
def empty(N, B, C, device, dtype):
wkv_states = torch.empty((N, B, C, 3),
device=device,
dtype=torch.float)
shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
return BlockStateList(shift_states, wkv_states)

def __getitem__(self, layer: int):
return BlockState(
TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
ChannelMixState(self.shift_states[layer, 1]))

def __setitem__(self, layer: int, state: BlockState):
self.shift_states[layer, 0] = state.time_mix_state.shift_state
self.wkv_states[layer] = state.time_mix_state.wkv_state
self.shift_states[layer, 1] = state.channel_mix_state.shift_state
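
For orientation, a hypothetical usage of the BlockStateList container added above (the sizes are made up and this snippet is not part of the diff):

import torch  # assumes the BlockStateList / BlockState classes defined above

states = BlockStateList.create(N=24, B=2, C=1024, device="cpu", dtype=torch.bfloat16)
layer0 = states[0]   # a BlockState wrapping views into the two backing tensors
states[0] = layer0   # writing back copies the per-layer state into place
# states.shift_states has shape (N, 2, B, C): token-shift states for time-mix and channel-mix.
# states.wkv_states has shape (N, B, C, 3): the per-channel WKV accumulators, kept in float32,
# with the last slot initialized to -1e38 by create(). The whole recurrent state is now just
# two tensors, which is what allows it to cross a checkpoint boundary with gradients intact.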


from torch.utils.cpp_extension import load
@@ -109,8 +135,8 @@ def __init__(self, layer_id, n_layer, n_embd, dim_att):
@MyFunction
def forward(self, x, last_state: TimeMixState):
# Mix x with the previous timestep to produce xk, xv, xr
xx = torch.concat(
(last_state.token_shift_state.unsqueeze(1), x[:, :-1]), dim=1)
xx = torch.concat((last_state.shift_state.unsqueeze(1), x[:, :-1]),
dim=1)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
@@ -148,8 +174,8 @@ def __init__(self, layer_id, n_layer, n_embd, dim_ffn):

@MyFunction
def forward(self, x, last_state: ChannelMixState):
xx = torch.concat(
(last_state.token_shift_state.unsqueeze(1), x[:, :-1]), dim=1)
xx = torch.concat((last_state.shift_state.unsqueeze(1), x[:, :-1]),
dim=1)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
@@ -213,7 +239,7 @@ def backward(ctx, grad_output):
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
gy=gy*ctx.currentMask[:,None][None,:]
gy = gy * ctx.currentMask[:, None][None, :]
return (grad_output, gy, None, None)


@@ -367,26 +393,30 @@ def deepspeed_offload(self) -> bool:
return "offload_optimizer" in cfg or "offload_parameters" in cfg
return False

def forward(self, idx, last_states: List[BlockState]):
def forward(self, idx: torch.Tensor, last_shift_states: torch.Tensor,
last_wkv_states: torch.Tensor):
B, T = idx.size()
assert T <= self.ctx_len, "Cannot forward, model ctx_len is exhausted."

x = self.emb(idx)

new_states = []
for block, last_state in zip(self.blocks, last_states):
new_states = BlockStateList.empty(self.n_layer, B, self.n_embd,
x.device, x.dtype)
for i, (block, last_state) in enumerate(
zip(self.blocks,
BlockStateList(last_shift_states, last_wkv_states))):
if self.grad_cp:
x, new_state = deepspeed.checkpointing.checkpoint(
block, x, last_state)
else:
x, new_state = block(x, last_state)
new_states.append(new_state)
new_states[i] = new_state

x = self.ln_out(x)

x = self.head(x)

return x, new_states
return x, new_states.shift_states, new_states.wkv_states
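
(A hypothetical call into this rewritten forward, assuming `model` is an instance of the class above and `idx` is a (B, T) tensor of token ids; illustrative only:)

B = idx.shape[0]
states = BlockStateList.create(model.n_layer, B, model.n_embd,
                               idx.device, model.emb.weight.dtype)
logits, new_shift_states, new_wkv_states = model(idx, states.shift_states,
                                                 states.wkv_states)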

def compute_loss(self, batch, batch_idx, do_cutoff: bool):
seq = batch['input_ids']
@@ -396,14 +426,15 @@ def compute_loss(self, batch, batch_idx, do_cutoff: bool):
# Check if attention mask is set; if not, initialize it
if seq_mask is None or seq_mask.ndim != 2:
seq_mask = torch.ones_like(seq[:, 1:])

if do_cutoff:
prev_step = 0
for step, len_cut in zip(self.ctx_len_warmup_steps,
self.ctx_len_cutoffs):
if prev_step <= self.global_step < step and len_cut < seq.shape[1] - 1:
self.ctx_len_cutoffs):
if prev_step <= self.global_step < step and len_cut < seq.shape[
1] - 1:
pos = randint(0, seq.shape[1] - len_cut - 1)

# Original
# seq = seq[:, pos:pos + len_cut + 1]

@@ -421,60 +452,72 @@ def compute_loss(self, batch, batch_idx, do_cutoff: bool):
C = self.n_embd
total_mask_sum = torch.sum(seq_mask)

def checkpointed_step(idx, targets, mask, prev_loss, last_states,
prev_steps):
logits, new_states = self(idx, last_states)
def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
last_wkv_states, prev_steps):
logits, new_shift_states, new_wkv_states = self(
idx, last_shift_states, last_wkv_states)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
targets.view(-1), reduction="none")
targets.view(-1),
reduction="none")
submask = mask.view(-1)[:loss.shape[0]]
submask_sum=torch.sum(submask)
submask_sum = torch.sum(submask)

# Special handling of empty mask
# Special handling of empty mask
# (possible when real_ctx_len is larger than ctx_len, which results in 'chunking')
if(submask_sum==0):
if submask_sum == 0:
loss = torch.sum(loss * submask) / 1
loss = L2Wrap.apply(loss, logits, total_mask_sum, submask)
new_steps = prev_steps # + submask_sum
new_loss = prev_loss+loss
return new_loss, new_states, new_steps

# Handling with mask
loss = torch.sum(loss * submask) / submask_sum
loss = L2Wrap.apply(loss, logits, total_mask_sum, submask)
new_steps = prev_steps + submask_sum
new_loss = prev_loss * (prev_steps / new_steps) + loss * (
1 - prev_steps / new_steps)
return new_loss, new_states, new_steps

total_loss = torch.tensor(0, dtype=self.emb.weight.dtype)
new_steps = prev_steps # + submask_sum
new_loss = prev_loss + loss
else:
# Handling with mask
loss = torch.sum(loss * submask) / submask_sum
loss = L2Wrap.apply(loss, logits, total_mask_sum, submask)
new_steps = prev_steps + submask_sum
new_loss = prev_loss * (prev_steps / new_steps) + loss * (
1 - prev_steps / new_steps)

return new_loss, new_shift_states, new_wkv_states, new_steps

total_loss = torch.tensor(
0, dtype=self.emb.weight.dtype).requires_grad_()
steps = 0
states = [
init_block_state(B, C, seq.device, self.emb.weight.dtype)
] * self.n_layer
states = BlockStateList.create(self.n_layer, B, C, seq.device,
self.emb.weight.dtype)
for i in range(math.ceil(T / self.ctx_len)):
if i != math.ceil(T / self.ctx_len) - 1:
total_loss, states, steps = deepspeed.checkpointing.checkpoint(
total_loss, new_shift_states, new_wkv_states, steps = deepspeed.checkpointing.checkpoint(
checkpointed_step,
idx[:, i * self.ctx_len:(i + 1) * self.ctx_len],
targets[:, i * self.ctx_len:(i + 1) * self.ctx_len],
seq_mask[:, i * self.ctx_len:(i + 1) * self.ctx_len],
total_loss,
states,
states.shift_states,
states.wkv_states,
steps,
)
else:
total_loss, states, steps = checkpointed_step(
total_loss, new_shift_states, new_wkv_states, steps = checkpointed_step(
idx[:, i * self.ctx_len:(i + 1) * self.ctx_len],
targets[:, i * self.ctx_len:(i + 1) * self.ctx_len],
seq_mask[:, i * self.ctx_len:(i + 1) * self.ctx_len],
total_loss,
states,
states.shift_states,
states.wkv_states,
steps,
)
states = BlockStateList(new_shift_states, new_wkv_states)
gc.collect()
# torch.cuda.empty_cache()

# Wandb logging only, if an active run exists
if wandb.run is not None:
wandb.log({'substep': batch_idx, 'real_ctx_len': T, 'train/loss': total_loss, 'trainer/global_step':self.global_step})
wandb.log({
'substep': batch_idx,
'real_ctx_len': T,
'train/loss': total_loss,
'trainer/global_step': self.global_step
})

return total_loss
