PoSE context length extension #1567

Merged 8 commits on Apr 27, 2024
Changes from 4 of 8 commits (an intermediate revision of the diff)
7 changes: 7 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -503,10 +503,17 @@ class Config:
     unfrozen_parameters: Optional[List[str]] = None
 
     sequence_len: int = Field(default=512)
+    min_sample_len: Optional[int] = None
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
 
+    # for PoSE context length extension
+    use_pose: Optional[bool] = None
+    pose_split_on_token_ids: Optional[List[int]] = None
+    pose_max_context_len: Optional[int] = None
+    pose_num_chunks: Optional[int] = None
+
     pretrain_multipack_buffer_size: Optional[int] = 10_000
     pretrain_multipack_attn: Optional[bool] = Field(
         default=True,
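These fields map one-to-one onto axolotl's YAML config. A minimal sketch of how the new options might be set together (values are illustrative, not recommendations from this PR):

sequence_len: 8192              # length of the raw training samples
min_sample_len: 64              # optional floor; shorter samples are dropped

# PoSE context length extension
use_pose: true
pose_max_context_len: 65536     # the extended context window to simulate
pose_num_chunks: 2              # fallback segment count when no split tokens are set
pose_split_on_token_ids: [2]    # illustrative: only skip right before this token id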
105 changes: 100 additions & 5 deletions src/axolotl/utils/trainer.py
@@ -1,9 +1,10 @@
"""Module containing the Trainer class and related functions"""
import math
import os
import random
from contextlib import contextmanager
from functools import partial
from typing import List
from typing import List, Optional

import numpy as np
import torch
@@ -98,17 +99,87 @@ def add_position_ids(sample):
     return sample
 
 
+def add_pose_position_ids(
+    sample,
+    max_context_len=32768,
+    split_on_token_ids: Optional[List[int]] = None,
+    chunks: int = 2,
+):
+    """
+    use the PoSE technique to extend the context length by randomly skipping
+    positions in the context. We only want to skip right before tokens in
+    the split_on_token_ids list. We should attempt to randomly distribute
+    the skips, but we don't need the final position_ids to be the full
+    context_len. There may be multiple turns in the context, so we want to
+    make sure we take into account the maximum possible number of skips
+    remaining in each sample.
+    """
+
+    input_ids = sample["input_ids"]
+    sample_len = len(input_ids)
+    max_skips = max_context_len - sample_len
+
+    if split_on_token_ids is None:
+        split_on_token_ids = []
+
+    if split_on_token_ids:
+        split_indices = [
+            i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
+        ]
+    else:
+        split_indices = [sample_len // chunks]

Review comment: I don't think this is going to work for any n_chunks > 2 right?

Collaborator Author: ah, you're right
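For illustration, one way the fallback could be generalized to more than two chunks (a sketch with a hypothetical helper name, not the fix that eventually landed):

def even_split_indices(sample_len: int, chunks: int) -> list:
    # chunks - 1 evenly spaced interior split points instead of a single midpoint
    return [sample_len * i // chunks for i in range(1, chunks)]

print(even_split_indices(10, 2))  # [5]
print(even_split_indices(10, 4))  # [2, 5, 7]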

+    split_indices.append(len(input_ids))  # make sure we go to the end of the sample
+    if split_indices[0] < 2:
+        # drop the first split index if it's too close to the beginning
+        split_indices = split_indices[1:]
+
+    position_ids = []
+    prev_index = 0
+    total_skips = 0
+
+    for split_index in split_indices:
+        num_skips = (
+            random.randint(0, max_skips)  # nosec B311
+            if prev_index != 0 and max_skips
+            else 0
+        )
+        max_skips -= num_skips
+        total_skips += num_skips
+
+        segment_position_ids = list(
+            range(prev_index + total_skips, split_index + total_skips)
+        )
+
+        position_ids.extend(segment_position_ids)
+        prev_index = split_index
+
+    position_ids = torch.tensor(position_ids)
+
+    sample["position_ids"] = position_ids
+    sample["length"] = len(position_ids)
+    assert len(position_ids) == len(input_ids)
+
+    return sample


 def add_length(sample):
     sample["length"] = len(sample["input_ids"])
     return sample
 
 
-def drop_long_seq(sample, sequence_len=2048):
-    return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+    return (
+        len(sample["input_ids"]) <= sequence_len
+        and len(sample["input_ids"]) >= min_sequence_len
+    )
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+    drop_long = partial(
+        drop_long_seq,
+        sequence_len=cfg.sequence_len,
+        min_sequence_len=cfg.min_sample_len or 2,
+    )
     with zero_first(is_main_process()):
         if cfg.is_preprocess:
             min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -153,7 +224,31 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Group By Length",
)

if cfg.sample_packing:
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
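To make the position skipping concrete, here is a small worked example against add_pose_position_ids as added above (the token ids are made up; the exact output varies per random draw):

from axolotl.utils.trainer import add_pose_position_ids

# Illustrative 6-token sample stretched toward a 12-position context.
# Token id 99 stands in for a turn-boundary token we may skip before.
sample = {"input_ids": [1, 5, 7, 99, 3, 99]}
out = add_pose_position_ids(sample, max_context_len=12, split_on_token_ids=[99])

print(out["position_ids"])
# One possible draw: tensor([0, 1, 2, 8, 9, 10]) -- contiguous within each
# segment, with random jumps at segment boundaries, so the model trains on
# position ids spanning a wider range than the raw sample length.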