From f17c9e2812eb7c603667daac8905a7e98bf707a4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Apr 2024 09:12:30 -0400 Subject: [PATCH 1/8] PoSE wip --- .../config/models/input/v0_4_1/__init__.py | 4 + src/axolotl/utils/trainer.py | 83 ++++++++++++++++++- 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index d99155ac25..6b5f6bd22b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -507,6 +507,10 @@ class Config: eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None + # for PoSE context length extension + use_pose: Optional[bool] = None + pose_split_on_token_ids: Optional[List[int]] = None + pretrain_multipack_buffer_size: Optional[int] = 10_000 pretrain_multipack_attn: Optional[bool] = Field( default=True, diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 808fbb59f5..7f484bdb66 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,9 +1,10 @@ """Module containing the Trainer class and related functions""" import math import os +import random from contextlib import contextmanager from functools import partial -from typing import List +from typing import List, Optional import numpy as np import torch @@ -98,6 +99,64 @@ def add_position_ids(sample): return sample +def add_pose_position_ids( + sample, max_context_len=32768, split_on_token_ids: Optional[List[int]] = None +): + """ + use the PoSE technique to extend the context length by randomly skipping + positions in the context. We only want to skip right before tokens in + the split_on_token_ids list. We should attempt to randomly distribute + the skips, but we don't need the final position_ids to be the full + context_len. 
There may be multiple turns in the context, so we want to + make sure we take into account the maximum possible number of skips + remaining in each sample. + """ + + input_ids = sample["input_ids"] + sample_len = len(input_ids) + max_skips = max_context_len - sample_len + + if split_on_token_ids is None: + split_on_token_ids = [] + + split_indices = [ + i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids + ] + if split_indices[0] < 2: + # drop the first split index if it's too close to the beginning + split_indices = split_indices[1:] + + if len(split_indices) == 0: + position_ids = torch.arange(sample_len) + else: + position_ids = [] + prev_index = -1 + total_skips = 0 + + for split_index in split_indices: + num_skips = ( + random.randint(0, max_skips) # nosec B311 + if prev_index != -1 and max_skips + else 0 + ) + max_skips -= num_skips + total_skips += num_skips + + segment_position_ids = list( + range(prev_index + 1 + total_skips, split_index + total_skips) + ) + + position_ids.extend(segment_position_ids) + prev_index = split_index + + position_ids = torch.tensor(position_ids) + + sample["position_ids"] = position_ids + sample["length"] = len(position_ids) + + return sample + + def add_length(sample): sample["length"] = len(sample["input_ids"]) return sample @@ -153,7 +212,27 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): desc="Group By Length", ) - if cfg.sample_packing: + if cfg.use_pose: + pose_fn = partial( + add_pose_position_ids, + max_context_len=cfg.sequence_len, + split_on_token_ids=cfg.pose_split_on_token_ids, + ) + train_dataset = train_dataset.map( + pose_fn, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Add position_id column (PoSE)", + ) + if cfg.eval_sample_packing is not False: + if eval_dataset: + eval_dataset = eval_dataset.map( + pose_fn, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Add position_id column 
(PoSE)", + ) + elif cfg.sample_packing: train_dataset = train_dataset.map( add_position_ids, num_proc=cfg.dataset_processes, From aacdbc34daf367da0459fd64c42e47baf4edc4bd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Apr 2024 13:57:40 -0400 Subject: [PATCH 2/8] fixes for pose splitting --- src/axolotl/utils/trainer.py | 62 ++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 7f484bdb66..99ce774c55 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -100,7 +100,10 @@ def add_position_ids(sample): def add_pose_position_ids( - sample, max_context_len=32768, split_on_token_ids: Optional[List[int]] = None + sample, + max_context_len=32768, + split_on_token_ids: Optional[List[int]] = None, + chunks: int = 2, ): """ use the PoSE technique to extend the context length by randomly skipping @@ -119,40 +122,42 @@ def add_pose_position_ids( if split_on_token_ids is None: split_on_token_ids = [] - split_indices = [ - i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids - ] + if split_on_token_ids: + split_indices = [ + i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids + ] + else: + split_indices = [sample_len // chunks] + split_indices.append(len(input_ids)) # make sure we go to the end of the sample if split_indices[0] < 2: # drop the first split index if it's too close to the beginning split_indices = split_indices[1:] - if len(split_indices) == 0: - position_ids = torch.arange(sample_len) - else: - position_ids = [] - prev_index = -1 - total_skips = 0 - - for split_index in split_indices: - num_skips = ( - random.randint(0, max_skips) # nosec B311 - if prev_index != -1 and max_skips - else 0 - ) - max_skips -= num_skips - total_skips += num_skips + position_ids = [] + prev_index = 0 + total_skips = 0 - segment_position_ids = list( - range(prev_index + 1 + total_skips, split_index 
+ total_skips) - ) + for split_index in split_indices: + num_skips = ( + random.randint(0, max_skips) # nosec B311 + if prev_index != 0 and max_skips + else 0 + ) + max_skips -= num_skips + total_skips += num_skips + + segment_position_ids = list( + range(prev_index + total_skips, split_index + total_skips) + ) - position_ids.extend(segment_position_ids) - prev_index = split_index + position_ids.extend(segment_position_ids) + prev_index = split_index - position_ids = torch.tensor(position_ids) + position_ids = torch.tensor(position_ids) sample["position_ids"] = position_ids sample["length"] = len(position_ids) + assert len(position_ids) == len(input_ids) return sample @@ -162,8 +167,11 @@ def add_length(sample): return sample -def drop_long_seq(sample, sequence_len=2048): - return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 +def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): + return ( + len(sample["input_ids"]) <= sequence_len + and len(sample["input_ids"]) >= min_sequence_len + ) def process_datasets_for_packing(cfg, train_dataset, eval_dataset): From 87007846beb1dc1f407921c28e634e105db9826b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 24 Apr 2024 10:54:58 -0400 Subject: [PATCH 3/8] set pose context len so we can pick that up separately from the usable training context len --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/trainer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 6b5f6bd22b..b5727d9629 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -510,6 +510,7 @@ class Config: # for PoSE context length extension use_pose: Optional[bool] = None pose_split_on_token_ids: Optional[List[int]] = None + pose_max_context_len: Optional[int] = None
pretrain_multipack_buffer_size: Optional[int] = 10_000 pretrain_multipack_attn: Optional[bool] = Field( diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 99ce774c55..80df4ccd24 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -223,7 +223,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if cfg.use_pose: pose_fn = partial( add_pose_position_ids, - max_context_len=cfg.sequence_len, + max_context_len=cfg.pose_max_context_len, split_on_token_ids=cfg.pose_split_on_token_ids, ) train_dataset = train_dataset.map( From cd089f9819f36d23832b9f1a6baa8d2f69ff8577 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 Apr 2024 16:02:55 -0400 Subject: [PATCH 4/8] support min sample len and define num chunks --- .../utils/config/models/input/v0_4_1/__init__.py | 2 ++ src/axolotl/utils/trainer.py | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index b5727d9629..0fb794ba32 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -503,6 +503,7 @@ class Config: unfrozen_parameters: Optional[List[str]] = None sequence_len: int = Field(default=512) + min_sample_len: Optional[int] = None sample_packing: Optional[bool] = None eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None @@ -511,6 +512,7 @@ class Config: use_pose: Optional[bool] = None pose_split_on_token_ids: Optional[List[int]] = None pose_max_context_len: Optional[int] = None + pose_num_chunks: Optional[int] = None pretrain_multipack_buffer_size: Optional[int] = 10_000 pretrain_multipack_attn: Optional[bool] = Field( diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 80df4ccd24..95b6f11be9 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py 
@@ -175,7 +175,11 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): def process_datasets_for_packing(cfg, train_dataset, eval_dataset): - drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) + drop_long = partial( + drop_long_seq, + sequence_len=cfg.sequence_len, + min_sequence_len=cfg.min_sample_len or 2, + ) with zero_first(is_main_process()): if cfg.is_preprocess: min_input_len = np.min(get_dataset_lengths(train_dataset)) @@ -221,10 +225,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): ) if cfg.use_pose: + pose_kwargs = {} + if cfg.pose_num_chunks is not None: + pose_kwargs["chunks"] = cfg.pose_num_chunks pose_fn = partial( add_pose_position_ids, max_context_len=cfg.pose_max_context_len, split_on_token_ids=cfg.pose_split_on_token_ids, + **pose_kwargs, ) train_dataset = train_dataset.map( pose_fn, From 6ffa083cff954bd49b5b55c0591f0ce0e8a3372b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 Apr 2024 19:53:51 -0400 Subject: [PATCH 5/8] fix chunk splitting --- src/axolotl/utils/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 95b6f11be9..ceb2d4ced8 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -127,7 +127,8 @@ def add_pose_position_ids( i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids ] else: - split_indices = [sample_len // chunks] + chunk_len = sample_len // chunks + split_indices = [i * chunk_len for i in range(1, chunks)] split_indices.append(len(input_ids)) # make sure we go to the end of the sample if split_indices[0] < 2: # drop the first split index if it's too close to the beginning From 43c4c97e058346039f768dca63fe03b3c81980c9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 Apr 2024 20:07:58 -0400 Subject: [PATCH 6/8] support for curriculum/ordered learning with pose --- src/axolotl/core/trainer_builder.py | 7 +++++++ 
src/axolotl/utils/trainer.py | 1 + 2 files changed, 8 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 6bddb95740..09651bdc9b 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -212,6 +212,10 @@ class AxolotlTrainingArguments(TrainingArguments): default=None, metadata={"help": "path under the model to access the layers"}, ) + curriculum_sampling: Optional[bool] = field( + default=None, + metadata={"help": "whether to use sequential sampling for curriculum learning"}, + ) class AxolotlTrainer(Trainer): @@ -347,6 +351,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: lengths=get_dataset_lengths(self.train_dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, ) + if self.args.curriculum_sampling: + return SequentialSampler(self.train_dataset) return super()._get_train_sampler() def _get_eval_sampler( @@ -1193,6 +1199,7 @@ def build(self, total_num_steps): False if self.cfg.ddp else None ) training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length + training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling report_to = None if self.cfg.use_wandb: report_to = "wandb" diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ceb2d4ced8..656d511d42 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -241,6 +241,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): load_from_cache_file=not cfg.is_preprocess, desc="Add position_id column (PoSE)", ) + train_dataset = train_dataset.sort(lambda x: x["position_ids"][-1]) if cfg.eval_sample_packing is not False: if eval_dataset: eval_dataset = eval_dataset.map( From 2f45a04fa18b96f55c7f51957bb8089e77f5fa5c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 Apr 2024 20:10:39 -0400 Subject: [PATCH 7/8] fix sequence len sort --- src/axolotl/utils/trainer.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 656d511d42..2e3728cc8a 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -154,6 +154,7 @@ def add_pose_position_ids( position_ids.extend(segment_position_ids) prev_index = split_index + sample["sequence_len"] = position_ids[-1] position_ids = torch.tensor(position_ids) sample["position_ids"] = position_ids @@ -241,7 +242,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): load_from_cache_file=not cfg.is_preprocess, desc="Add position_id column (PoSE)", ) - train_dataset = train_dataset.sort(lambda x: x["position_ids"][-1]) + train_dataset = train_dataset.sort("sequence_len") if cfg.eval_sample_packing is not False: if eval_dataset: eval_dataset = eval_dataset.map( From 74d1284ee09d205ced178ef8ad9120701086858e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 Apr 2024 20:16:17 -0400 Subject: [PATCH 8/8] add curriculum_sampling to pydantic --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 0fb794ba32..e27a8ddd52 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -507,6 +507,7 @@ class Config: sample_packing: Optional[bool] = None eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None + curriculum_sampling: Optional[bool] = None # for PoSE context length extension use_pose: Optional[bool] = None