DRY: A modern repetition penalty that reliably prevents looping #5677

Merged (13 commits) on May 20, 2024
4 changes: 4 additions & 0 deletions extensions/openai/typing.py
@@ -22,6 +22,10 @@ class GenerationOptions(BaseModel):
top_a: float = 0
epsilon_cutoff: float = 0
eta_cutoff: float = 0
dry_multiplier: float = 0
dry_base: float = 1.75
dry_allowed_length: int = 2
dry_sequence_breakers: str = '["\\n", ":", "\\"", "*"]'

Owner (oobabooga) commented:

I'm not comfortable with this representation for the sequence breakers. I think that it should be processed in the same way as "Custom stopping strings" for consistency, without the [] list syntax.

Contributor (PR author) replied:

I expect that most clients will use a JSON library to build this value from some internal list representation. That library will output the format used here, no further processing required.

Telling developers "pass a JSON array" makes everything clear, including details like which quotation marks are valid and how escape sequences work. "Pass something like a JSON array, but without the brackets" just sounds weird.

IMO, if anything, it is the stopping strings that should be changed to match this parameter.
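
For example, a Python client that keeps its breakers in a native list gets the expected wire format for free. A minimal sketch:

    import json

    breakers = ["\n", ":", "\"", "*"]

    # json.dumps emits exactly the documented default value shown above.
    value = json.dumps(breakers)
    assert value == '["\\n", ":", "\\"", "*"]'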

guidance_scale: float = 1
negative_prompt: str = ''
penalty_alpha: float = 0
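
End to end, the new fields ride along with any other generation options. A hedged sketch of a request against the OpenAI-compatible API (the local address and the parameter values are illustrative; adjust them to your setup):

    import requests

    response = requests.post(
        "http://127.0.0.1:5000/v1/completions",
        json={
            "prompt": "Once upon a time",
            "max_tokens": 200,
            "dry_multiplier": 0.8,  # 0 (the default) leaves DRY disabled
            "dry_base": 1.75,
            "dry_allowed_length": 2,
            "dry_sequence_breakers": '["\\n", ":", "\\"", "*"]',
        },
    )
    print(response.json()["choices"][0]["text"])
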
12 changes: 12 additions & 0 deletions modules/loaders.py
@@ -168,6 +168,10 @@ def transformers_samplers():
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'tfs',
'top_a',
'repetition_penalty',
@@ -241,6 +245,10 @@ def transformers_samplers():
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'tfs',
'top_a',
'repetition_penalty',
@@ -299,6 +307,10 @@ def transformers_samplers():
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'tfs',
'top_a',
'repetition_penalty',
4 changes: 4 additions & 0 deletions modules/presets.py
@@ -32,6 +32,10 @@ def default_preset():
'top_a': 0,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'dry_multiplier': 0,
'dry_base': 1.75,
'dry_allowed_length': 2,
'dry_sequence_breakers': '["\\n", ":", "\\"", "*"]',
'guidance_scale': 1,
'penalty_alpha': 0,
'mirostat_mode': 0,
105 changes: 98 additions & 7 deletions modules/sampler_hijack.py
@@ -1,3 +1,4 @@
import json
import math
import pprint

@@ -220,6 +221,74 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
return scores


class DRYLogitsProcessor(LogitsProcessor):
def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int):
self.allowed_length = allowed_length
self.multiplier = multiplier
self.base = base
self.sequence_breakers = sequence_breakers
self._range = _range

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self._range > 0:
input_ids = input_ids[:, -self._range:]

for input_ids_row, scores_row in zip(input_ids, scores):
# Raw integer must be extracted here to check for set membership.
last_token = input_ids_row[-1].item()

if last_token in self.sequence_breakers:
continue

# Exclude the last token as it always matches.
match_indices = (input_ids_row[:-1] == last_token).nonzero()

# Stores the maximum matching sequence length
# for each token immediately following the sequence in the input.
match_lengths = {}

for i in match_indices:
next_token = input_ids_row[i+1].item()

if next_token in self.sequence_breakers:
continue

# We have already found that `last_token` matches at this index,
# so the match is at least of length 1.
match_length = 1

# Extend the match backwards as far as possible.
while True:
j = i - match_length
if j < 0:
# Start of input reached.
break

previous_token = input_ids_row[-(match_length+1)].item()
if input_ids_row[j] != previous_token:
# Start of match reached.
break

if previous_token in self.sequence_breakers:
# Sequence-breaking token reached.
break

match_length += 1

if next_token in match_lengths:
match_lengths[next_token] = max(match_length, match_lengths[next_token])
else:
match_lengths[next_token] = match_length

# Apply penalties.
for token, match_length in match_lengths.items():
if match_length >= self.allowed_length:
penalty = self.multiplier * self.base ** (match_length - self.allowed_length)
scores_row[token] -= penalty

return scores


class MirostatLogitsWarper(LogitsWarper):
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if mirostat_mode not in [2]:
@@ -448,20 +517,38 @@ def custom_sort_key(obj):


def get_logits_processor_patch(self, **kwargs):
repetition_penalty = kwargs['generation_config'].repetition_penalty
presence_penalty = kwargs['generation_config'].presence_penalty
frequency_penalty = kwargs['generation_config'].frequency_penalty
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
generation_config = kwargs['generation_config']

do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0)
if do_rep_pen_hijack:
kwargs['generation_config'].repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created

result = self._get_logits_processor_old(**kwargs)

if do_rep_pen_hijack:
for i in range(len(result)):
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, presence_penalty, frequency_penalty, repetition_penalty_range)
result[i] = RepetitionPenaltyLogitsProcessorWithRange(
generation_config.repetition_penalty,
generation_config.presence_penalty,
generation_config.frequency_penalty,
generation_config.repetition_penalty_range
)

if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
sequence_breaker_strings = json.loads(generation_config.dry_sequence_breakers)
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}

result.append(
DRYLogitsProcessor(
multiplier=generation_config.dry_multiplier,
base=generation_config.dry_base,
allowed_length=generation_config.dry_allowed_length,
sequence_breakers=sequence_breakers,
_range=generation_config.repetition_penalty_range,
)
)

return result

@@ -477,6 +564,10 @@ def generation_config_init_patch(self, **kwargs):
self.smoothing_curve = kwargs.pop("smoothing_curve", 1.0)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.dry_multiplier = kwargs.pop("dry_multiplier", 0.0)
self.dry_base = kwargs.pop("dry_base", 1.75)
self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '["\\n", ":", "\\"", "*"]')
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
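
Two details in this file are worth spelling out. First, the penalty grows exponentially with the length of the repeated sequence: penalty = multiplier * base ** (match_length - allowed_length). A quick sketch with an illustrative multiplier of 0.8 and the default base and allowed length:

    multiplier, base, allowed_length = 0.8, 1.75, 2

    for match_length in range(2, 7):
        print(match_length, round(multiplier * base ** (match_length - allowed_length), 2))

    # 2 0.8   (a minimum-length repeat is penalized only mildly)
    # 3 1.4
    # 4 2.45
    # 5 4.29
    # 6 7.5   (longer loops are quickly priced out of the distribution)

Second, the 'a' prefix used when encoding sequence breakers: tokenizers such as SentencePiece encode a string differently at the start of a text than after other text, so encoding f'a{s}' and keeping only the last token yields the mid-text form of the breaker, which is the form that actually occurs in generated output.
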
2 changes: 1 addition & 1 deletion modules/text_generation.py
@@ -263,7 +263,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):

def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size']:
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size']:
if k in state:
generate_params[k] = state[k]

4 changes: 4 additions & 0 deletions modules/ui.py
@@ -133,6 +133,10 @@ def list_interface_input_elements():
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'repetition_penalty',
'presence_penalty',
'frequency_penalty',
6 changes: 6 additions & 0 deletions modules/ui_parameters.py
@@ -38,6 +38,12 @@ def create_ui(default_preset):
shared.gradio['presence_penalty'] = gr.Slider(0, 2, value=generate_params['presence_penalty'], step=0.05, label='presence_penalty')
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
with gr.Blocks():
shared.gradio['dry_multiplier'] = gr.Slider(0, 5, value=generate_params['dry_multiplier'], step=0.01, label='dry_multiplier')
shared.gradio['dry_base'] = gr.Slider(1, 4, value=generate_params['dry_base'], step=0.01, label='dry_base')
shared.gradio['dry_allowed_length'] = gr.Slider(1, 20, value=generate_params['dry_allowed_length'], step=1, label='dry_allowed_length')
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers')

gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)")

with gr.Column():