Skip to content

Commit

Permalink
Merge branch 'main' into prep-times
Browse files Browse the repository at this point in the history
  • Loading branch information
elronbandel authored Jul 23, 2024
2 parents 657061d + 4a62f7d commit cbf8537
Show file tree
Hide file tree
Showing 25 changed files with 779 additions and 3 deletions.
12 changes: 12 additions & 0 deletions docs/docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ Demonstrates how different formats and system prompts affect the input provided

Related documentation: :ref:`Formatting tutorial <adding_format>`.

Evaluate the impact of different demonstration example selections
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

Demonstrates how different methods of selecting the demonstrations in in-context learning affect the results.
Three methods are considered: fixed selection of example demonstrations for all test instances,
random selection of example demonstrations for each test instance,
and choosing the demonstration examples most (lexically) similar to each test instance.

`Example code <https://github.com/IBM/unitxt/blob/main/examples/evaluate_different_demo_selections.py>`_

Related documentation: :ref:`Formatting tutorial <adding_format>`.

LLM as Judges
--------------

Expand Down
62 changes: 62 additions & 0 deletions examples/evaluate_different_demo_selections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pandas as pd
from unitxt import get_logger
from unitxt.api import evaluate, load_dataset
from unitxt.inference import IbmGenAiInferenceEngine
from unitxt.splitters import CloseTextSampler, FixedIndicesSampler, RandomSampler
from unitxt.text_utils import print_dict

logger = get_logger()

# This example evaluates different demo-selection strategies on a classification task.
# Each strategy is evaluated with 1, 3 and 5 shots, drawn from a demo pool of 300 examples:
# - RandomSampler: randomly sample a different set of examples for each test instance
# - CloseTextSampler: select the lexically closest examples from the demo pool for each test instance
# - FixedIndicesSampler: select the same fixed set of demo examples for all instances

card = "cards.ledgar"
model_name = "google/flan-t5-xxl"
inference_model = IbmGenAiInferenceEngine(model_name=model_name, max_new_tokens=32)


# One result row per (num_demos, sampler) combination.
df = pd.DataFrame(columns=["num_demos", "sampler", "f1_micro", "ci_low", "ci_high"])

for num_demos in [1, 3, 5]:
    for demo_sampler in [
        RandomSampler(),
        CloseTextSampler(field="text"),
        # NOTE(review): index 3 is skipped in this fixed set — confirm this is intentional.
        FixedIndicesSampler(indices=[0, 1, 2, 4, 5]),
    ]:
        dataset = load_dataset(
            card=card,
            template="templates.classification.multi_class.title",
            num_demos=num_demos,
            demos_pool_size=300,
            loader_limit=400,
            max_test_instances=200,
            sampler=demo_sampler,
        )

        test_dataset = dataset["test"]

        predictions = inference_model.infer(test_dataset)
        evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)

        logger.info(
            f"Sample input and output for sampler {demo_sampler} and num_demos '{num_demos}':"
        )
        print_dict(
            evaluated_dataset[0],
            keys_to_print=["source", "prediction", "processed_prediction"],
        )
        global_scores = evaluated_dataset[0]["score"]["global"]

        df.loc[len(df)] = [
            num_demos,
            demo_sampler.to_json(),
            global_scores["score"],
            global_scores["score_ci_low"],
            global_scores["score_ci_high"],
        ]

df = df.round(decimals=2)
logger.info(df.to_markdown())
166 changes: 166 additions & 0 deletions prepare/cards/translation/flores101.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from unitxt.blocks import Copy, LoadHF, Set, SplitRandomMix, TaskCard
from unitxt.catalog import add_to_catalog
from unitxt.test_utils.card import test_card

# https://localizely.com/iso-639-2-list/
# Mapping from ISO 639-2/3 language codes to English language names.
# See: https://localizely.com/iso-639-2-list/
iso_lang_code_mapping = {
    "eng": "English",
    "afr": "Afrikaans",
    "amh": "Amharic",
    "ara": "Arabic",
    "hye": "Armenian",
    "asm": "Assamese",
    "ast": "Asturian",
    "azj": "Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bos": "Bosnian",
    "bul": "Bulgarian",
    "mya": "Burmese",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "zho_simpl": "Chinese (Simplified)",
    "zho_trad": "Chinese (Traditional)",
    "hrv": "Croatian",
    "ces": "Czech",
    "dan": "Danish",
    "nld": "Dutch",
    "est": "Estonian",
    "tgl": "Tagalog",
    "fin": "Finnish",
    "fra": "French",
    "ful": "Fulah",
    "glg": "Galician",
    "lug": "Ganda",
    "kat": "Georgian",
    "deu": "German",
    "ell": "Greek",
    "guj": "Gujarati",
    "hau": "Hausa",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hun": "Hungarian",
    "isl": "Icelandic",
    "ibo": "Igbo",
    "ind": "Indonesian",
    "gle": "Irish",
    "ita": "Italian",
    "jpn": "Japanese",
    "jav": "Javanese",
    "kea": "Kabuverdianu",
    "kam": "Kamba",
    "kan": "Kannada",
    "kaz": "Kazakh",
    "khm": "Khmer",
    "kor": "Korean",
    "kir": "Kyrgyz",
    "lao": "Lao",
    "lav": "Latvian",
    "lin": "Lingala",
    "lit": "Lithuanian",
    "luo": "Dholuo",
    "ltz": "Luxembourgish",
    "mkd": "Macedonian",
    "msa": "Malay",
    "mal": "Malayalam",
    "mlt": "Maltese",
    "mri": "Maori",
    "mar": "Marathi",
    "mon": "Mongolian",
    "npi": "Nepali",
    "nso": "Northern Sotho",
    "nob": "Norwegian Bokmål",
    "nya": "Nyanja",
    "oci": "Occitan",
    "ory": "Odia",
    "orm": "Oromo",
    "pus": "Pashto",
    "fas": "Persian",
    "pol": "Polish",
    "por": "Portuguese",
    "pan": "Punjabi",
    "ron": "Romanian",
    "rus": "Russian",
    "srp": "Serbian",
    "sna": "Shona",
    "snd": "Sindhi",
    "slk": "Slovak",
    "slv": "Slovenian",
    "som": "Somali",
    "ckb": "Sorani Kurdish",
    "spa": "Spanish",
    "swh": "Swahili",
    "swe": "Swedish",
    "tgk": "Tajik",
    "tam": "Tamil",
    "tel": "Telugu",
    "tha": "Thai",
    "tur": "Turkish",
    "ukr": "Ukrainian",
    "umb": "Umbundu",
    "urd": "Urdu",
    "uzb": "Uzbek",
    "vie": "Vietnamese",
    "cym": "Welsh",
    "wol": "Wolof",
    "xho": "Xhosa",
    "yor": "Yoruba",
    "zul": "Zulu",
}


langs_to_include = [  # langs currently supported by sacrebleu
    "ara",
    "fra",
    "deu",
    "jpn",
    "kor",
    "por",
    "ron",
    "spa",
]

# Non-English supported languages, in the mapping's insertion order.
# Fixed: the original used a substring test ("eng" not in lang), which would
# wrongly exclude any future code merely containing "eng"; exact comparison is
# the intent and yields identical results for the current data.
langs = [
    lang
    for lang in iso_lang_code_mapping
    if lang != "eng" and lang in langs_to_include
]
# Both translation directions: each language -> English and English -> each language.
pairs = [{"src": lang, "tgt": "eng"} for lang in langs] + [
    {"src": "eng", "tgt": lang} for lang in langs
]

# Build, test, and register one translation card per language-pair direction.
for direction in pairs:
    src = direction["src"]
    tgt = direction["tgt"]
    card = TaskCard(
        loader=LoadHF(path="gsarti/flores_101", name="all"),
        preprocess_steps=[
            SplitRandomMix({"validation": "dev", "test": "devtest"}),
            # Map the per-language sentence columns onto the generic task fields.
            Copy(
                field_to_field={
                    f"sentence_{src}": "text",
                    f"sentence_{tgt}": "translation",
                },
            ),
            Set(
                fields={
                    "source_language": iso_lang_code_mapping[src].lower(),
                    "target_language": iso_lang_code_mapping[tgt].lower(),
                }
            ),
        ],
        task="tasks.translation.directed",
        templates="templates.translation.directed.all",
    )

    test_card(card, demos_taken_from="test")
    add_to_catalog(card, f"cards.mt.flores_101.{src}_{tgt}", overwrite=True)

if __name__ == "__main__":
    from unitxt import load_dataset

    # Smoke-test one generated card by loading it end to end.
    ds = load_dataset(
        "card=cards.mt.flores_101.eng_deu,template_card_index=0",
    )

    # The original computed this sample and discarded it; print it so the
    # smoke test actually shows output.
    print(ds["test"][0])
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ extend-immutable-calls = ["fastapi.Depends", "fastapi.params.Depends", "fastapi.
"src".msg = "Use unitxt outside src/ and relative imports inside src/ and install unitxt from source with `pip install -e '.[dev]'`."

[tool.codespell]
ignore-words-list = 'rouge,ot,ans,nd,cann'
ignore-words-list = 'rouge,ot,ans,nd,cann,som,tha,vie'
check-filenames = true
check-hidden = false
regex = "(?<![a-z])[a-z'`]+|[A-Z][a-z'`]*|[a-z]+'[a-z]*|[a-z]+(?=[_-])|[a-z]+(?=[A-Z])|\\d+"
skip = '*cards/trec*,*cards/belebele*,*cards/amazon_mass*,*cards/reuters21578*,*cards/attaq_500*,*cards/cohere_for_ai*,*egg-info*,*/logs/*'
skip = '*cards/mt/flores101*,*cards/trec*,*cards/belebele*,*cards/amazon_mass*,*cards/reuters21578*,*cards/attaq_500*,*cards/cohere_for_ai*,*egg-info*,*/logs/*'
33 changes: 33 additions & 0 deletions src/unitxt/catalog/cards/mt/flores_101/ara_eng.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "gsarti/flores_101",
"name": "all"
},
"preprocess_steps": [
{
"__type__": "split_random_mix",
"mix": {
"validation": "dev",
"test": "devtest"
}
},
{
"__type__": "copy",
"field_to_field": {
"sentence_ara": "text",
"sentence_eng": "translation"
}
},
{
"__type__": "set",
"fields": {
"source_language": "arabic",
"target_language": "english"
}
}
],
"task": "tasks.translation.directed",
"templates": "templates.translation.directed.all"
}
33 changes: 33 additions & 0 deletions src/unitxt/catalog/cards/mt/flores_101/deu_eng.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "gsarti/flores_101",
"name": "all"
},
"preprocess_steps": [
{
"__type__": "split_random_mix",
"mix": {
"validation": "dev",
"test": "devtest"
}
},
{
"__type__": "copy",
"field_to_field": {
"sentence_deu": "text",
"sentence_eng": "translation"
}
},
{
"__type__": "set",
"fields": {
"source_language": "german",
"target_language": "english"
}
}
],
"task": "tasks.translation.directed",
"templates": "templates.translation.directed.all"
}
33 changes: 33 additions & 0 deletions src/unitxt/catalog/cards/mt/flores_101/eng_ara.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "gsarti/flores_101",
"name": "all"
},
"preprocess_steps": [
{
"__type__": "split_random_mix",
"mix": {
"validation": "dev",
"test": "devtest"
}
},
{
"__type__": "copy",
"field_to_field": {
"sentence_eng": "text",
"sentence_ara": "translation"
}
},
{
"__type__": "set",
"fields": {
"source_language": "english",
"target_language": "arabic"
}
}
],
"task": "tasks.translation.directed",
"templates": "templates.translation.directed.all"
}
33 changes: 33 additions & 0 deletions src/unitxt/catalog/cards/mt/flores_101/eng_deu.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "gsarti/flores_101",
"name": "all"
},
"preprocess_steps": [
{
"__type__": "split_random_mix",
"mix": {
"validation": "dev",
"test": "devtest"
}
},
{
"__type__": "copy",
"field_to_field": {
"sentence_eng": "text",
"sentence_deu": "translation"
}
},
{
"__type__": "set",
"fields": {
"source_language": "english",
"target_language": "german"
}
}
],
"task": "tasks.translation.directed",
"templates": "templates.translation.directed.all"
}
Loading

0 comments on commit cbf8537

Please sign in to comment.