Showing 25 changed files with 779 additions and 3 deletions.
@@ -0,0 +1,62 @@
import pandas as pd
from unitxt import get_logger
from unitxt.api import evaluate, load_dataset
from unitxt.inference import IbmGenAiInferenceEngine
from unitxt.splitters import CloseTextSampler, FixedIndicesSampler, RandomSampler
from unitxt.text_utils import print_dict

logger = get_logger()

# This example evaluates different kinds of demo selection strategies on a classification task.
# Each strategy is evaluated in 1-, 3-, and 5-shot settings, with demos drawn from a pool of 300 examples.
# RandomSampler - randomly samples a different set of examples for each test instance
# CloseTextSampler - selects the lexically closest examples from the demo pool for each test instance
# FixedIndicesSampler - selects the same fixed set of demo examples for all instances

card = "cards.ledgar"
model_name = "google/flan-t5-xxl"
inference_model = IbmGenAiInferenceEngine(model_name=model_name, max_new_tokens=32)


df = pd.DataFrame(columns=["num_demos", "sampler", "f1_micro", "ci_low", "ci_high"])

for num_demos in [1, 3, 5]:
    for demo_sampler in [
        RandomSampler(),
        CloseTextSampler(field="text"),
        FixedIndicesSampler(indices=[0, 1, 2, 4, 5]),
    ]:
        dataset = load_dataset(
            card=card,
            template="templates.classification.multi_class.title",
            num_demos=num_demos,
            demos_pool_size=300,
            loader_limit=400,
            max_test_instances=200,
            sampler=demo_sampler,
        )

        test_dataset = dataset["test"]

        predictions = inference_model.infer(test_dataset)
        evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)

        logger.info(
            f"Sample input and output for sampler {demo_sampler} and num_demos '{num_demos}':"
        )
        print_dict(
            evaluated_dataset[0],
            keys_to_print=["source", "prediction", "processed_prediction"],
        )
        global_scores = evaluated_dataset[0]["score"]["global"]

        df.loc[len(df)] = [
            num_demos,
            demo_sampler.to_json(),
            global_scores["score"],
            global_scores["score_ci_low"],
            global_scores["score_ci_high"],
        ]

df = df.round(decimals=2)
logger.info(df.to_markdown())
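
The example above assumes access to IBM GenAI. As a hedged alternative (not part of this commit), the same comparison should run with a local Hugging Face pipeline engine from unitxt, assuming HFPipelineBasedInferenceEngine accepts the same model_name and max_new_tokens arguments; the smaller flan-t5-base is an illustrative choice, not the model used above.

from unitxt.inference import HFPipelineBasedInferenceEngine

# Assumed substitute for IbmGenAiInferenceEngine above; "google/flan-t5-base"
# is an illustrative, smaller model chosen only to keep a local run cheap.
inference_model = HFPipelineBasedInferenceEngine(
    model_name="google/flan-t5-base", max_new_tokens=32
)
# The rest of the loop (load_dataset, infer, evaluate, results table) is unchanged.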
@@ -0,0 +1,166 @@
from unitxt.blocks import Copy, LoadHF, Set, SplitRandomMix, TaskCard
from unitxt.catalog import add_to_catalog
from unitxt.test_utils.card import test_card

# https://localizely.com/iso-639-2-list/
iso_lang_code_mapping = {
    "eng": "English",
    "afr": "Afrikaans",
    "amh": "Amharic",
    "ara": "Arabic",
    "hye": "Armenian",
    "asm": "Assamese",
    "ast": "Asturian",
    "azj": "Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bos": "Bosnian",
    "bul": "Bulgarian",
    "mya": "Burmese",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "zho_simpl": "Chinese (Simplified)",
    "zho_trad": "Chinese (Traditional)",
    "hrv": "Croatian",
    "ces": "Czech",
    "dan": "Danish",
    "nld": "Dutch",
    "est": "Estonian",
    "tgl": "Tagalog",
    "fin": "Finnish",
    "fra": "French",
    "ful": "Fulah",
    "glg": "Galician",
    "lug": "Ganda",
    "kat": "Georgian",
    "deu": "German",
    "ell": "Greek",
    "guj": "Gujarati",
    "hau": "Hausa",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hun": "Hungarian",
    "isl": "Icelandic",
    "ibo": "Igbo",
    "ind": "Indonesian",
    "gle": "Irish",
    "ita": "Italian",
    "jpn": "Japanese",
    "jav": "Javanese",
    "kea": "Kabuverdianu",
    "kam": "Kamba",
    "kan": "Kannada",
    "kaz": "Kazakh",
    "khm": "Khmer",
    "kor": "Korean",
    "kir": "Kyrgyz",
    "lao": "Lao",
    "lav": "Latvian",
    "lin": "Lingala",
    "lit": "Lithuanian",
    "luo": "Dholuo",
    "ltz": "Luxembourgish",
    "mkd": "Macedonian",
    "msa": "Malay",
    "mal": "Malayalam",
    "mlt": "Maltese",
    "mri": "Maori",
    "mar": "Marathi",
    "mon": "Mongolian",
    "npi": "Nepali",
    "nso": "Northern Sotho",
    "nob": "Norwegian Bokmål",
    "nya": "Nyanja",
    "oci": "Occitan",
    "ory": "Odia",
    "orm": "Oromo",
    "pus": "Pashto",
    "fas": "Persian",
    "pol": "Polish",
    "por": "Portuguese",
    "pan": "Punjabi",
    "ron": "Romanian",
    "rus": "Russian",
    "srp": "Serbian",
    "sna": "Shona",
    "snd": "Sindhi",
    "slk": "Slovak",
    "slv": "Slovenian",
    "som": "Somali",
    "ckb": "Sorani Kurdish",
    "spa": "Spanish",
    "swh": "Swahili",
    "swe": "Swedish",
    "tgk": "Tajik",
    "tam": "Tamil",
    "tel": "Telugu",
    "tha": "Thai",
    "tur": "Turkish",
    "ukr": "Ukrainian",
    "umb": "Umbundu",
    "urd": "Urdu",
    "uzb": "Uzbek",
    "vie": "Vietnamese",
    "cym": "Welsh",
    "wol": "Wolof",
    "xho": "Xhosa",
    "yor": "Yoruba",
    "zul": "Zulu",
}


langs_to_include = [  # langs currently supported by sacrebleu
    "ara",
    "fra",
    "deu",
    "jpn",
    "kor",
    "por",
    "ron",
    "spa",
]

langs = [
    lang
    for lang in iso_lang_code_mapping.keys()
    if ("eng" not in lang and lang in langs_to_include)
]
pairs = [{"src": lang, "tgt": "eng"} for lang in langs] + [
    {"src": "eng", "tgt": lang} for lang in langs
]

for pair in pairs:
    card = TaskCard(
        loader=LoadHF(path="gsarti/flores_101", name="all"),
        preprocess_steps=[
            SplitRandomMix({"validation": "dev", "test": "devtest"}),
            Copy(
                field_to_field={
                    f"sentence_{pair['src']}": "text",
                    f"sentence_{pair['tgt']}": "translation",
                },
            ),
            Set(
                fields={
                    "source_language": iso_lang_code_mapping[pair["src"]].lower(),
                    "target_language": iso_lang_code_mapping[pair["tgt"]].lower(),
                }
            ),
        ],
        task="tasks.translation.directed",
        templates="templates.translation.directed.all",
    )

    test_card(card, demos_taken_from="test")
    add_to_catalog(
        card, f"cards.mt.flores_101.{pair['src']}_{pair['tgt']}", overwrite=True
    )

if __name__ == "__main__":
    from unitxt import load_dataset

    ds = load_dataset(
        "card=cards.mt.flores_101.eng_deu,template_card_index=0",
    )

    ds["test"][0]
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,33 @@
{
    "__type__": "task_card",
    "loader": {
        "__type__": "load_hf",
        "path": "gsarti/flores_101",
        "name": "all"
    },
    "preprocess_steps": [
        {
            "__type__": "split_random_mix",
            "mix": {
                "validation": "dev",
                "test": "devtest"
            }
        },
        {
            "__type__": "copy",
            "field_to_field": {
                "sentence_ara": "text",
                "sentence_eng": "translation"
            }
        },
        {
            "__type__": "set",
            "fields": {
                "source_language": "arabic",
                "target_language": "english"
            }
        }
    ],
    "task": "tasks.translation.directed",
    "templates": "templates.translation.directed.all"
}
@@ -0,0 +1,33 @@
{
    "__type__": "task_card",
    "loader": {
        "__type__": "load_hf",
        "path": "gsarti/flores_101",
        "name": "all"
    },
    "preprocess_steps": [
        {
            "__type__": "split_random_mix",
            "mix": {
                "validation": "dev",
                "test": "devtest"
            }
        },
        {
            "__type__": "copy",
            "field_to_field": {
                "sentence_deu": "text",
                "sentence_eng": "translation"
            }
        },
        {
            "__type__": "set",
            "fields": {
                "source_language": "german",
                "target_language": "english"
            }
        }
    ],
    "task": "tasks.translation.directed",
    "templates": "templates.translation.directed.all"
}
@@ -0,0 +1,33 @@
{
    "__type__": "task_card",
    "loader": {
        "__type__": "load_hf",
        "path": "gsarti/flores_101",
        "name": "all"
    },
    "preprocess_steps": [
        {
            "__type__": "split_random_mix",
            "mix": {
                "validation": "dev",
                "test": "devtest"
            }
        },
        {
            "__type__": "copy",
            "field_to_field": {
                "sentence_eng": "text",
                "sentence_ara": "translation"
            }
        },
        {
            "__type__": "set",
            "fields": {
                "source_language": "english",
                "target_language": "arabic"
            }
        }
    ],
    "task": "tasks.translation.directed",
    "templates": "templates.translation.directed.all"
}
@@ -0,0 +1,33 @@
{
    "__type__": "task_card",
    "loader": {
        "__type__": "load_hf",
        "path": "gsarti/flores_101",
        "name": "all"
    },
    "preprocess_steps": [
        {
            "__type__": "split_random_mix",
            "mix": {
                "validation": "dev",
                "test": "devtest"
            }
        },
        {
            "__type__": "copy",
            "field_to_field": {
                "sentence_eng": "text",
                "sentence_deu": "translation"
            }
        },
        {
            "__type__": "set",
            "fields": {
                "source_language": "english",
                "target_language": "german"
            }
        }
    ],
    "task": "tasks.translation.directed",
    "templates": "templates.translation.directed.all"
}
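
Once registered, these cards behave like any other unitxt card. A minimal end-to-end sketch under assumptions: HFPipelineBasedInferenceEngine as the engine, flan-t5-base as the model, and the small max_test_instances limit are all illustrative choices, not taken from this commit.

from unitxt.api import evaluate, load_dataset
from unitxt.inference import HFPipelineBasedInferenceEngine

# Illustrative recipe settings; max_test_instances only keeps the run small.
dataset = load_dataset(
    "card=cards.mt.flores_101.eng_deu,template_card_index=0,max_test_instances=20"
)
test_dataset = dataset["test"]

inference_model = HFPipelineBasedInferenceEngine(
    model_name="google/flan-t5-base", max_new_tokens=64
)
predictions = inference_model.infer(test_dataset)
evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)

# Global scores for the directed translation task (e.g. the sacrebleu-based metric).
print(evaluated_dataset[0]["score"]["global"])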