From 679264117be25b071616e8544bc89d74f52f8da3 Mon Sep 17 00:00:00 2001
From: Blealtan Cao
Date: Wed, 28 Jun 2023 17:00:09 +0800
Subject: [PATCH] Fix time checkpointing.

The previous implementation skipped the backward pass for every chunk except
the last. Fixing this currently breaks DeepSpeed stage 2, though. Also make
the states properly checkpointed by expanding the state classes into raw
tensors.
---
 RWKV-v4neo/config-example.yaml |   3 +-
 RWKV-v4neo/src/model.py        | 155 +++++++++++++++++++++------------
 2 files changed, 101 insertions(+), 57 deletions(-)

diff --git a/RWKV-v4neo/config-example.yaml b/RWKV-v4neo/config-example.yaml
index cd4225fc..064e180c 100644
--- a/RWKV-v4neo/config-example.yaml
+++ b/RWKV-v4neo/config-example.yaml
@@ -24,7 +24,8 @@ trainer:
   # For more details see:
   # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed-zero-stage-2
   #
-  strategy: deepspeed_stage_2_offload
+  #!FIXME: currently only deepspeed_stage_1 is supported, because deepspeed cannot handle repeated backward hooks.
+  strategy: deepspeed_stage_1
 
   # Floating point precision for the model, because RWKV is built FOR bf16
   # you should pretty much never change this setting
diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 1aa5eb98..a454d02d 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -2,15 +2,17 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
 
-import os, math
+import gc, math
 from random import randint
 from typing import List, Optional
+
 import numpy as np
 import torch
 # torch._C._jit_set_profiling_executor(True)
 # torch._C._jit_set_profiling_mode(True)
 import torch.nn as nn
 from torch.nn import functional as F
+
 import lightning as L
 from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
 from lightning.pytorch.strategies import DeepSpeedStrategy
@@ -32,32 +34,56 @@
 
 class TimeMixState:
 
-    def __init__(self, token_shift_state: torch.Tensor,
-                 wkv_state: torch.Tensor):
-        self.token_shift_state = token_shift_state
+    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
+        self.shift_state = shift_state
         self.wkv_state = wkv_state
 
 
 class ChannelMixState:
 
-    def __init__(self, token_shift_state: torch.Tensor):
-        self.token_shift_state = token_shift_state
+    def __init__(self, shift_state: torch.Tensor):
+        self.shift_state = shift_state
 
 
 class BlockState:
 
-    def __init__(self, time_mix_state: torch.Tensor,
-                 channel_mix_state: torch.Tensor):
+    def __init__(self, time_mix_state: TimeMixState,
+                 channel_mix_state: ChannelMixState):
         self.time_mix_state = time_mix_state
         self.channel_mix_state = channel_mix_state
 
 
-def init_block_state(B, C, device, dtype):
-    wkv_state = torch.zeros((B, C, 3), device=device, dtype=torch.float)
-    wkv_state[:, :, -1] = -1e38
-    token_shift_state = torch.zeros((B, C), device=device, dtype=dtype)
-    return BlockState(TimeMixState(token_shift_state, wkv_state),
-                      ChannelMixState(token_shift_state))
+class BlockStateList:
+
+    def __init__(self, shift_states, wkv_states):
+        self.wkv_states = wkv_states
+        self.shift_states = shift_states
+
+    @staticmethod
+    def create(N, B, C, device, dtype):
+        result = BlockStateList.empty(N, B, C, device, dtype)
+        result.wkv_states[:] = 0
+        result.wkv_states[:, :, :, -1] = -1e38
+        result.shift_states[:] = 0
+        return result
+
+    @staticmethod
+    def empty(N, B, C, device, dtype):
+        wkv_states = torch.empty((N, B, C, 3),
+                                 device=device,
+                                 dtype=torch.float)
+        shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
+        return BlockStateList(shift_states, wkv_states)
+
+    def __getitem__(self, layer: int):
+        return BlockState(
+            TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
+            ChannelMixState(self.shift_states[layer, 1]))
+
+    def __setitem__(self, layer: int, state: BlockState):
+        self.shift_states[layer, 0] = state.time_mix_state.shift_state
+        self.wkv_states[layer] = state.time_mix_state.wkv_state
+        self.shift_states[layer, 1] = state.channel_mix_state.shift_state
 
 
 from torch.utils.cpp_extension import load
@@ -109,8 +135,8 @@ def __init__(self, layer_id, n_layer, n_embd, dim_att):
     @MyFunction
     def forward(self, x, last_state: TimeMixState):
         # Mix x with the previous timestep to produce xk, xv, xr
-        xx = torch.concat(
-            (last_state.token_shift_state.unsqueeze(1), x[:, :-1]), dim=1)
+        xx = torch.concat((last_state.shift_state.unsqueeze(1), x[:, :-1]),
+                          dim=1)
         xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
         xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
         xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
@@ -148,8 +174,8 @@ def __init__(self, layer_id, n_layer, n_embd, dim_ffn):
 
     @MyFunction
     def forward(self, x, last_state: ChannelMixState):
-        xx = torch.concat(
-            (last_state.token_shift_state.unsqueeze(1), x[:, :-1]), dim=1)
+        xx = torch.concat((last_state.shift_state.unsqueeze(1), x[:, :-1]),
+                          dim=1)
         xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
         xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
         k = self.key(xk)
@@ -213,7 +239,7 @@ def backward(ctx, grad_output):
         maxx, ids = torch.max(y, -1, keepdim=True)
         gy = torch.zeros_like(y)
         gy.scatter_(-1, ids, maxx * factor)
-        gy=gy*ctx.currentMask[:,None][None,:]
+        gy = gy * ctx.currentMask[:, None][None, :]
         return (grad_output, gy, None, None)
 
 
@@ -367,26 +393,30 @@ def deepspeed_offload(self) -> bool:
             return "offload_optimizer" in cfg or "offload_parameters" in cfg
         return False
 
-    def forward(self, idx, last_states: List[BlockState]):
+    def forward(self, idx: torch.Tensor, last_shift_states: torch.Tensor,
+                last_wkv_states: torch.Tensor):
         B, T = idx.size()
         assert T <= self.ctx_len, "Cannot forward, model ctx_len is exhausted."
 
         x = self.emb(idx)
 
-        new_states = []
-        for block, last_state in zip(self.blocks, last_states):
+        new_states = BlockStateList.empty(self.n_layer, B, self.n_embd,
+                                          x.device, x.dtype)
+        for i, (block, last_state) in enumerate(
+                zip(self.blocks,
+                    BlockStateList(last_shift_states, last_wkv_states))):
             if self.grad_cp:
                 x, new_state = deepspeed.checkpointing.checkpoint(
                     block, x, last_state)
             else:
                 x, new_state = block(x, last_state)
-            new_states.append(new_state)
+            new_states[i] = new_state
 
         x = self.ln_out(x)
 
         x = self.head(x)
 
-        return x, new_states
+        return x, new_states.shift_states, new_states.wkv_states
 
     def compute_loss(self, batch, batch_idx, do_cutoff: bool):
         seq = batch['input_ids']
@@ -396,14 +426,15 @@ def compute_loss(self, batch, batch_idx, do_cutoff: bool):
         # Check if attent mask is set, if not initialize it
         if seq_mask is None or seq_mask.ndim != 2:
             seq_mask = torch.ones_like(seq[:, 1:])
-        
+
         if do_cutoff:
             prev_step = 0
             for step, len_cut in zip(self.ctx_len_warmup_steps,
-                                    self.ctx_len_cutoffs):
-                if prev_step <= self.global_step < step and len_cut < seq.shape[1] - 1:
+                                     self.ctx_len_cutoffs):
+                if prev_step <= self.global_step < step and len_cut < seq.shape[
+                        1] - 1:
                     pos = randint(0, seq.shape[1] - len_cut - 1)
-                    
+
                     # Original
                     # seq = seq[:, pos:pos + len_cut + 1]
 
@@ -421,60 +452,72 @@ def compute_loss(self, batch, batch_idx, do_cutoff: bool):
         C = self.n_embd
         total_mask_sum = torch.sum(seq_mask)
 
-        def checkpointed_step(idx, targets, mask, prev_loss, last_states,
-                              prev_steps):
-            logits, new_states = self(idx, last_states)
+        def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
+                              last_wkv_states, prev_steps):
+            logits, new_shift_states, new_wkv_states = self(
+                idx, last_shift_states, last_wkv_states)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
-                                   targets.view(-1), reduction="none")
+                                   targets.view(-1),
+                                   reduction="none")
             submask = mask.view(-1)[:loss.shape[0]]
-            submask_sum=torch.sum(submask)
+            submask_sum = torch.sum(submask)
 
-            # Special handling of empty mask 
+            # Special handling of empty mask
             # (possible when real_ctx_len is larger then ctx_len, which results into 'chunking')
-            if(submask_sum==0):
+            if submask_sum == 0:
                 loss = torch.sum(loss * submask) / 1
                 loss = L2Wrap.apply(loss, logits, total_mask_sum, submask)
-                new_steps = prev_steps # + submask_sum
-                new_loss = prev_loss+loss
-                return new_loss, new_states, new_steps
-
-            # Handling with mask
-            loss = torch.sum(loss * submask) / submask_sum
-            loss = L2Wrap.apply(loss, logits, total_mask_sum, submask)
-            new_steps = prev_steps + submask_sum
-            new_loss = prev_loss * (prev_steps / new_steps) + loss * (
-                1 - prev_steps / new_steps)
-            return new_loss, new_states, new_steps
-
-        total_loss = torch.tensor(0, dtype=self.emb.weight.dtype)
+                new_steps = prev_steps  # + submask_sum
+                new_loss = prev_loss + loss
+            else:
+                # Handling with mask
+                loss = torch.sum(loss * submask) / submask_sum
+                loss = L2Wrap.apply(loss, logits, total_mask_sum, submask)
+                new_steps = prev_steps + submask_sum
+                new_loss = prev_loss * (prev_steps / new_steps) + loss * (
+                    1 - prev_steps / new_steps)
+
+            return new_loss, new_shift_states, new_wkv_states, new_steps
+
+        total_loss = torch.tensor(
+            0, dtype=self.emb.weight.dtype).requires_grad_()
         steps = 0
-        states = [
-            init_block_state(B, C, seq.device, self.emb.weight.dtype)
-        ] * self.n_layer
+        states = BlockStateList.create(self.n_layer, B, C, seq.device,
+                                       self.emb.weight.dtype)
        for i in range(math.ceil(T / self.ctx_len)):
             if i != math.ceil(T / self.ctx_len) - 1:
-                total_loss, states, steps = deepspeed.checkpointing.checkpoint(
+                total_loss, new_shift_states, new_wkv_states, steps = deepspeed.checkpointing.checkpoint(
                     checkpointed_step,
                     idx[:, i * self.ctx_len:(i + 1) * self.ctx_len],
                     targets[:, i * self.ctx_len:(i + 1) * self.ctx_len],
                     seq_mask[:, i * self.ctx_len:(i + 1) * self.ctx_len],
                     total_loss,
-                    states,
+                    states.shift_states,
+                    states.wkv_states,
                     steps,
                 )
             else:
-                total_loss, states, steps = checkpointed_step(
+                total_loss, new_shift_states, new_wkv_states, steps = checkpointed_step(
                     idx[:, i * self.ctx_len:(i + 1) * self.ctx_len],
                     targets[:, i * self.ctx_len:(i + 1) * self.ctx_len],
                     seq_mask[:, i * self.ctx_len:(i + 1) * self.ctx_len],
                     total_loss,
-                    states,
+                    states.shift_states,
+                    states.wkv_states,
                     steps,
                 )
+            states = BlockStateList(new_shift_states, new_wkv_states)
+            gc.collect()
+            # torch.cuda.empty_cache()
 
         # Wandb logging only, if an active run exists
         if wandb.run is not None:
-            wandb.log({'substep': batch_idx, 'real_ctx_len': T, 'train/loss': total_loss, 'trainer/global_step':self.global_step})
+            wandb.log({
+                'substep': batch_idx,
+                'real_ctx_len': T,
+                'train/loss': total_loss,
+                'trainer/global_step': self.global_step
+            })
 
         return total_loss
 
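A note on the flattened state layout that BlockStateList introduces in the diff above. The sketch below is illustrative only: init_states is a hypothetical helper, not part of the patch, and the comment about the running maximum is my reading of the -1e38 initialisation; the shapes and initial values themselves follow create()/empty().

import torch

def init_states(n_layer, B, C, device="cpu", dtype=torch.bfloat16):
    # shift_states[:, 0] backs TimeMixState.shift_state,
    # shift_states[:, 1] backs ChannelMixState.shift_state.
    shift_states = torch.zeros(n_layer, 2, B, C, device=device, dtype=dtype)
    # wkv_states keeps three accumulators per layer/channel in float32; the
    # last slot starts at -1e38 ("minus infinity"), which the WKV recurrence
    # uses as a running maximum for numerical stability (assumption based on
    # the -1e38 init in BlockStateList.create).
    wkv_states = torch.zeros(n_layer, B, C, 3, device=device, dtype=torch.float)
    wkv_states[..., -1] = -1e38
    return shift_states, wkv_states

shift, wkv = init_states(n_layer=2, B=1, C=8)
print(shift.shape, wkv.shape)  # torch.Size([2, 2, 1, 8]) torch.Size([2, 1, 8, 3])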
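The commit message points at two related problems: the backward pass only ran for the last chunk, and the states were not "really checkpointed" until the wrapper classes were expanded into raw tensors, i.e. only plain tensors should cross the checkpoint boundary. The following is a minimal, self-contained sketch of that pattern, using torch.utils.checkpoint as a stand-in for deepspeed.checkpointing.checkpoint (an assumption for illustration; the two APIs are not identical) and a toy chunk_step in place of checkpointed_step.

import torch
from torch.utils.checkpoint import checkpoint

def chunk_step(x, shift_states, wkv_states):
    # Stand-in for one chunk of the forward pass: it consumes the previous
    # chunk's state tensors and returns updated ones alongside an output.
    y = (x * shift_states.mean() + wkv_states.mean()).sum()
    return y, shift_states + x.mean(), wkv_states + 1.0

x = torch.randn(4, 8, requires_grad=True)
shift = torch.zeros(2, 4, 8)        # toy (layers, B, C) shapes
wkv = torch.zeros(2, 4, 8, 3)

total = x.new_zeros(())
for _ in range(3):
    # Every chunk goes through checkpoint; only raw tensors cross the
    # boundary.  use_reentrant=False needs a reasonably recent PyTorch.
    y, shift, wkv = checkpoint(chunk_step, x, shift, wkv, use_reentrant=False)
    total = total + y

total.backward()               # gradients flow back through all three chunks
print(x.grad.abs().sum() > 0)  # tensor(True)

The patched forward keeps the same contract: it accepts last_shift_states / last_wkv_states and returns the new pair, so compute_loss can rebuild a BlockStateList from whatever each checkpointed chunk hands back.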
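checkpointed_step also keeps total_loss as a running token-weighted mean rather than a plain sum, via new_loss = prev_loss * (prev_steps / new_steps) + loss * (1 - prev_steps / new_steps). A small numeric sketch (the numbers below are made up) shows that this incremental update reproduces the direct weighted mean over all chunks.

# Illustrative numbers only: per-chunk mean loss and unmasked tokens per chunk.
chunk_losses = [2.0, 1.0, 4.0]
chunk_tokens = [512, 512, 128]

total_loss, steps = 0.0, 0
for loss, n in zip(chunk_losses, chunk_tokens):
    new_steps = steps + n
    total_loss = total_loss * (steps / new_steps) + loss * (1 - steps / new_steps)
    steps = new_steps

expected = sum(l * n for l, n in zip(chunk_losses, chunk_tokens)) / sum(chunk_tokens)
assert abs(total_loss - expected) < 1e-9   # matches the direct weighted mean

The empty-mask branch is the exception: its loss is zero (the submask is all zeros), it leaves the step count unchanged, and it simply adds that zero loss so the chunk still participates in the graph.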