Skip to content

Commit

Permalink
Merge pull request #111 from Stability-AI/stratified-sampling-few-shot
Browse files Browse the repository at this point in the history
Add option to do stratified sampling for few-shot examples
  • Loading branch information
mrorii committed Oct 25, 2023
2 parents d209b9a + b46ea45 commit 9b42d41
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 7 deletions.
80 changes: 75 additions & 5 deletions lm_eval/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
from collections import defaultdict
from typing import Iterable
import numpy as np
import random
Expand Down Expand Up @@ -430,6 +431,8 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
self.download(data_dir, cache_dir, download_mode)
self._training_docs = None
self._fewshot_docs = None
self._target_to_docs = None
self._target_to_ratio = None

def download(self, data_dir=None, cache_dir=None, download_mode=None):
"""Downloads and returns the task dataset.
Expand Down Expand Up @@ -515,11 +518,63 @@ def _process_doc(self, doc):
"""
return doc

def fewshot_examples(self, k, rnd):
def fewshot_examples(self, k, rnd, stratified=False):
    """Sample ``k`` few-shot examples from the task's training docs.

    :param k: number of examples to draw
    :param rnd: ``random.Random`` instance used for sampling
    :param stratified: when True, stratify the draw by the doc's target
        (see ``_stratified_fewshot_examples``); otherwise draw uniformly
    :returns: list of ``k`` training docs
    """
    # Materialize (and cache) the training docs on first use.
    if self._training_docs is None:
        self._training_docs = list(self.training_docs())

    if not stratified:
        return rnd.sample(self._training_docs, k)
    return self._stratified_fewshot_examples(self._training_docs, k, rnd)

def _stratified_fewshot_examples(self, docs, k, rnd):
    """Draw ``k`` few-shot examples from ``docs`` with stratified sampling.

    The stratum of each doc is its target, as produced by
    ``self.doc_to_target``.

    WARNING: to avoid regrouping on every call, this method caches two
    structures derived from ``docs``:
      - ``self._target_to_docs``: target -> list of docs in that stratum
      - ``self._target_to_ratio``: target -> fraction of ``docs`` in that stratum
    Consequently ``docs`` MUST stay constant across calls. In practice it is
    either ``self._training_docs`` (tasks with training data) or
    ``self._fewshot_docs`` (tasks without), so the assumption holds.

    :param docs: the pool to sample from (constant across calls, see above)
    :param k: total number of examples to draw
    :param rnd: ``random.Random`` instance used for sampling/shuffling
    :returns: list of ``k`` docs, shuffled across strata
    """
    # Build (and cache) the per-target grouping on first use.
    if self._target_to_docs is None or self._target_to_ratio is None:
        grouped = defaultdict(list)
        for entry in docs:
            grouped[self.doc_to_target(entry)].append(entry)
        total = len(docs)
        self._target_to_docs = grouped
        self._target_to_ratio = {
            target: len(members) / total for target, members in grouped.items()
        }

    # `k` is usually fixed for a task, but that is not guaranteed, so
    # recompute the per-stratum quota on every call.
    quota = {
        target: int(ratio * k) for target, ratio in self._target_to_ratio.items()
    }
    # Flooring can leave us short of `k`; top up the smallest strata one
    # sample at a time (ties resolved by target insertion order via min()).
    shortfall = k - sum(quota.values())
    while shortfall > 0:
        quota[min(quota, key=quota.get)] += 1
        shortfall -= 1

    drawn = []
    for target, count in quota.items():
        drawn.extend(rnd.sample(self._target_to_docs[target], count))
    # Shuffle so the output order does not leak the stratum ordering.
    rnd.shuffle(drawn)
    return drawn

def doc_to_decontamination_query(self, doc):
print(
Expand Down Expand Up @@ -592,7 +647,13 @@ def fewshot_description(self):

@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
self,
doc,
num_fewshot,
provide_description=None,
rnd=None,
description=None,
stratified=False,
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
Expand All @@ -608,6 +669,8 @@ def fewshot_context(
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:param stratified: bool
When true, does stratified sampling, using the target from `self.doc_to_target` as the stratum.
:returns: str
The fewshot context.
"""
Expand Down Expand Up @@ -643,7 +706,9 @@ def fewshot_context(
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
fewshotex = self.fewshot_examples(
k=num_fewshot, rnd=rnd, stratified=stratified
)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
Expand All @@ -652,7 +717,12 @@ def fewshot_context(
else self.test_docs()
)

fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
if stratified:
fewshotex = self._stratified_fewshot_examples(
self._fewshot_docs, num_fewshot + 1, rnd=rnd
)
else:
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
Expand Down
16 changes: 15 additions & 1 deletion lm_eval/tasks/ja/jcola.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


class JCoLA(CoLA):
VERSION = 0.1
VERSION = 0.2
PROMPT_VERSION = 0.0
DATASET_PATH = "shunk031/JGLUE"
DATASET_NAME = "JCoLA"
Expand All @@ -43,6 +43,20 @@ def construct_requests(self, doc, ctx):
ll_false, _ = rf.loglikelihood(ctx, " %s" % self.CHOICES[0])
return ll_true, ll_false

def fewshot_context(
    self,
    doc,
    num_fewshot,
    provide_description=None,
    rnd=None,
    description=None,
    stratified=False,
):
    """Build the few-shot context for JCoLA, always stratifying.

    The ``stratified`` parameter is accepted only for signature
    compatibility with the base class; it is deliberately overridden to
    True so JCoLA few-shot examples are always stratified by target.
    """
    # NOTE(review): the base `fewshot_context` is decorated with
    # `@utils.positional_deprecated`; forward arguments by keyword so this
    # call does not hit the positional-use deprecation path — confirm
    # against `utils.positional_deprecated`.
    return super().fewshot_context(
        doc=doc,
        num_fewshot=num_fewshot,
        provide_description=provide_description,
        rnd=rnd,
        description=description,
        stratified=True,
    )


class JCoLAWithJAAlpacaPrompt(JCoLA):
"""
Expand Down
23 changes: 22 additions & 1 deletion lm_eval/tasks/ja/jnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class JNLIWithFintanPrompt(BalancedMultipleChoiceTask):
prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
"""

VERSION = 1.2
VERSION = 1.3
PROMPT_VERSION = 0.2
DATASET_PATH = "shunk031/JGLUE"
DATASET_NAME = "JNLI"
Expand Down Expand Up @@ -92,6 +92,27 @@ def construct_requests(self, doc, ctx):
lls.append(rf.greedy_until(ctx, [self.SEP]))
return lls

def fewshot_context(
    self,
    doc,
    num_fewshot,
    provide_description=None,
    rnd=None,
    description=None,
    stratified=False,
):
    """Build the few-shot context, always using stratified sampling.

    The ``stratified`` parameter is accepted only for signature
    compatibility with the base class; it is deliberately overridden to
    True.

    TODO: move this to `MultipleChoiceTask`.
    Directly implementing this in `MultipleChoiceTask` will break the task versioning
    as the metric definition will get updated, and thus we need to incrementally apply this to all
    tasks that inherit `MultipleChoiceTask` AND bump their task `VERSION`, and
    only after all tasks have been updated, then we can move this to `MultipleChoiceTask`.
    """
    # NOTE(review): the base `fewshot_context` is decorated with
    # `@utils.positional_deprecated`; forward arguments by keyword so this
    # call does not hit the positional-use deprecation path — confirm
    # against `utils.positional_deprecated`.
    return super().fewshot_context(
        doc=doc,
        num_fewshot=num_fewshot,
        provide_description=provide_description,
        rnd=rnd,
        description=description,
        stratified=True,
    )


class JNLIWithJAAlpacaPrompt(JNLIWithFintanPrompt):
"""
Expand Down

0 comments on commit 9b42d41

Please sign in to comment.