diff --git a/README.md b/README.md
index b256fb88..71606405 100644
--- a/README.md
+++ b/README.md
@@ -19,18 +19,27 @@
---
## Announcement
+
+- [2024-11] The `lmms-eval/v0.3.0` has been upgraded to support audio evaluations for audio models like Qwen2-Audio and Gemini_Audio across tasks such as AIR-Bench, Clotho-AQA, LibriSpeech, and more. Please refer to the [blog](https://github.com/EvolvingLMMs-Lab/lmms-eval/blob/main/docs/lmms-eval-0.3.md) for more details!
+
+- [2024-07] We have released the [technical report](https://arxiv.org/abs/2407.12772) and [LiveBench](https://huggingface.co/spaces/lmms-lab/LiveBench)!
+
+- [2024-06] The `lmms-eval/v0.2.0` has been upgraded to support video evaluations for video models like LLaVA-NeXT Video and Gemini 1.5 Pro across tasks such as EgoSchema, PerceptionTest, VideoMME, and more. Please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for more details!
+
+- [2024-03] We have released the first version of `lmms-eval`; please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.1/) for more details!
+
+
+We warmly welcome contributions from the open-source community! Below is a chronological list of recent tasks, models, and features added by our amazing contributors.
+
- [2024-10] We welcome the new task [NaturalBench](https://huggingface.co/datasets/BaiqiL/NaturalBench), a vision-centric VQA benchmark (NeurIPS'24) that challenges vision-language models with simple questions about natural imagery.
- [2024-10] We welcome the new task [TemporalBench](https://huggingface.co/datasets/microsoft/TemporalBench) for fine-grained temporal understanding and reasoning for videos, which reveals a huge (>30%) human-AI gap.
- [2024-10] We welcome the new tasks [VDC](https://rese1f.github.io/aurora-web/) for video detailed captioning, [MovieChat-1K](https://rese1f.github.io/MovieChat/) for long-form video understanding, and [Vinoground](https://vinoground.github.io/), a temporal counterfactual LMM benchmark composed of 1000 short natural video-caption pairs. We also welcome the new models [AuroraCap](https://github.com/rese1f/aurora) and [MovieChat](https://github.com/rese1f/MovieChat).
- [2024-09] We welcome the new tasks [MMSearch](https://mmsearch.github.io/) and [MME-RealWorld](https://mme-realworld.github.io/).
- [2024-09] We upgrade `lmms-eval` to `0.2.3` with more tasks and features. We support a compact set of language task evaluations (code credit to [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)), and we remove the registration logic at startup (for all models and tasks) to reduce overhead. Now `lmms-eval` only launches the necessary tasks/models. Please check the [release notes](https://github.com/EvolvingLMMs-Lab/lmms-eval/releases/tag/v0.2.3) for more details.
- [2024-08] We welcome the new models [LLaVA-OneVision](https://huggingface.co/papers/2408.03326) and [Mantis](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/162), and the new tasks [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench), [LongVideoBench](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/117), and [MMStar](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/158). We also provide a new SGLang Runtime API feature for the llava-onevision model; please refer to the [doc](https://github.com/EvolvingLMMs-Lab/lmms-eval/blob/main/docs/commands.md) for inference acceleration.
-- [2024-07] ๐๐ We have released the [technical report](https://arxiv.org/abs/2407.12772) and [LiveBench](https://huggingface.co/spaces/lmms-lab/LiveBench)!
- [2024-07] The `lmms-eval/v0.2.1` has been upgraded to support more models, including [LongVA](https://github.com/EvolvingLMMs-Lab/LongVA), [InternVL-2](https://github.com/OpenGVLab/InternVL), [VILA](https://github.com/NVlabs/VILA), and many more evaluation tasks, e.g., [Details Captions](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/136), [MLVU](https://arxiv.org/abs/2406.04264), [WildVision-Bench](https://huggingface.co/datasets/WildVision/wildvision-arena-data), [VITATECS](https://github.com/lscpku/VITATECS), and [LLaVA-Interleave-Bench](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/).
-- [2024-06] ๐ฌ๐ฌ The `lmms-eval/v0.2.0` has been upgraded to support video evaluations for video models like LLaVA-NeXT Video and Gemini 1.5 Pro across tasks such as EgoSchema, PerceptionTest, VideoMME, and more. Please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for more details
-
-- [2024-03] ๐๐ We have released the first version of `lmms-eval`, please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.1/) for more details
+
## Why `lmms-eval`?
diff --git a/docs/lmms-eval-0.3.md b/docs/lmms-eval-0.3.md
new file mode 100644
index 00000000..b402b70c
--- /dev/null
+++ b/docs/lmms-eval-0.3.md
@@ -0,0 +1,265 @@
+# Integration of Audio Evaluation in LMMs-Eval
+
+
+## **Introduction**
+
+Humans perceive the world through both sight and sound, integrating visual cues with auditory signals such as speech, environmental sounds, and emotional tones.
+
+This dual sensory input enhances decision-making and overall understanding. Similarly, for multimodal models to achieve human-like comprehension, it is essential that they process visual and auditory data together.
+
+While many models have made progress in integrating audio understanding, there is still no reproducible and efficient evaluation toolkit to fairly assess their capabilities.
+
+To address this, we introduce an upgrade to the `lmms-eval` framework, focusing on audio understanding. Building on the success of `lmms-eval/v0.2.0`, the new `lmms-eval/v0.3.0` includes dedicated modules and designs for audio tasks, ensuring consistent evaluation across audio and visual modalities.
+
+This upgrade includes multiple benchmarks for audio understanding and instruction following, enabling standardized and reproducible comparisons of various audio models.
+
+## Audio Evaluation Pipeline
+
+1. **Improved Pipeline for Audio Evaluations**
+
+ Here's a breakdown of how audio dataset support is added.
+
+ 1. **Load Audio:** Audio clips are stored on Hugging Face and are loaded via the `doc_to_audio` function.
+ - The snippet below shows how we handle audio datasets in lmms-eval.
+
+ ```python
+ def air_bench_doc_to_audio(doc):
+ return [doc["audio"]]
+ ```
+
+ 2. **Format questions:** Questions and instructions are defined in each task's `utils.py`. For some Audio Instruction Following (AIF) tasks, we create custom prompts aligned with Qwen2-Audio's evaluation format, since the default dataset instructions are not always clear enough. Model-specific prompts can also be added alongside the default instruction.
+ - The snippet below shows an example of formatting a question.
+
+ ```python
+ # This is the place where you format your question
+ def common_voice_15_doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"
+ ```
+
+ 3. **Process results:** Model outputs are evaluated using metrics taken from the official dataset implementations or aligned with the implementation in [AudioBench](https://github.com/AudioLLMs/AudioBench). We primarily adopt three types of metrics:
+
+ **a. Accuracy:** Used for tasks with definitive ground-truth answers, such as multiple-choice questions.
+
+ **b. WER:** Applied to some Automatic Speech Recognition (ASR) tasks; a minimal word-error-rate sketch is shown after the GPT-4 prompt below.
+
+ **c. GPT-4 Eval:** Applied to open-ended responses. We align the evaluation prompt with the implementation in [AudioBench](https://github.com/AudioLLMs/AudioBench).
+
+ - The snippet below shows an example prompt for the GPT-4 evaluation.
+
+ ```python
+ eval_prompt = """
+ [Question]
+ {question}
+
+ [Reference Answer]
+ {ground_truth}
+
+ [Model Answer]
+ {model_response}
+
+ [Task]
+ Rate the model's answer based on its alignment with the reference answer, focusing on accuracy and relevance to the reference provided. Please be critical on the details.
+ Criteria: Assess if the model's response mirrors the reference in terms of content, accuracy, and relevance.
+ Score0: The answer is completely misaligned, providing incorrect or irrelevant information compared to the reference.
+ Score1: The answer shows minimal alignment, often misunderstanding or providing irrelevant details unrelated to the reference.
+ Score2: The answer recognizes the topic but diverges significantly from the reference in accuracy or relevance.
+ Score3: The answer aligns with the reference generally but lacks detail or precise accuracy in some aspects.
+ Score4: The answer is mostly accurate and relevant, closely following the reference but could be clearer or more detailed.
+ Score5: The answer is highly accurate, detailed, and matches the reference answer perfectly, capturing its essence and detail.
+
+ Your response should be formatted as follows:
+ Explanation: (Provide a concise explanation of your rating, comparing the reference answer with the model's response. "The reference answer is [XXX], while the model's answer is [YYY]. I think ...")
+ Rating: (int)"""
+ ```
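+
+ As an aside, here is a minimal, generic sketch of how a WER-style score can be computed (for illustration only; the actual ASR scorers in `lmms-eval` follow the official dataset or AudioBench implementations):
+
+```python
+# Generic word error rate: (substitutions + insertions + deletions) / number of reference words.
+def word_error_rate(reference: str, hypothesis: str) -> float:
+    ref, hyp = reference.lower().split(), hypothesis.lower().split()
+    # d[i][j] = edit distance between the first i reference words and the first j hypothesis words
+    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
+    for i in range(len(ref) + 1):
+        d[i][0] = i
+    for j in range(len(hyp) + 1):
+        d[0][j] = j
+    for i in range(1, len(ref) + 1):
+        for j in range(1, len(hyp) + 1):
+            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
+            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)  # deletion, insertion, substitution
+    return d[len(ref)][len(hyp)] / max(len(ref), 1)
+
+print(word_error_rate("it is sunny today", "it is sunny today"))           # 0.0
+print(word_error_rate("it is sunny today", "today the weather is sunny"))  # 1.0
+```
+
+ The second example previews the format-rigidity issue discussed later in this post: a semantically reasonable paraphrase can still receive a very poor WER.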
+
+
+
+
+ 4. **Aggregate results:**
+ After evaluating each data instance, we aggregate the individual results into the overall evaluation metrics. Finally, we provide a summary table that consolidates all the evaluation results, similar to the one in [Google's Gemini report](https://arxiv.org/abs/2312.11805). A minimal sketch of this per-sample/aggregation contract follows the task-grouping example below.
+ 5. **Grouped Tasks:**
+ For tasks with multiple subsets, we group all subset tasks together. For example, the AirBench-Chat dataset includes 4 subsets: sound, music, speech, and mixed. By running `--tasks air_bench_chat`, all 4 subsets are evaluated together, eliminating the need to specify each subset individually. We summarize all the grouped task names in Table 1. This pipeline ensures a thorough and standardized evaluation process for audio, facilitating consistent and reliable performance assessment across various tasks and datasets.
+ - The snippet below shows an example YAML file for task grouping.
+
+ ```yaml
+ group: air_bench_chat
+ task:
+ - air_bench_chat_sound
+ - air_bench_chat_music
+ - air_bench_chat_speech
+ - air_bench_chat_mixed
+ ```
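+
+ For reference, the sketch below illustrates the per-sample/aggregation contract described in steps 3 and 4 with a simplified, hypothetical accuracy task (the function and field names here are illustrative; the real implementations live in each task's `utils.py`, e.g. `lmms_eval/tasks/air_bench/utils.py`):
+
+```python
+# Hypothetical, simplified example of the lmms-eval task contract:
+# process_results() runs once per document and returns {metric_name: value};
+# the aggregation function later receives the list of those values for the whole split.
+
+def example_process_results(doc: dict, results: list) -> dict:
+    # Exact-match accuracy against the ground-truth answer for one sample.
+    pred = results[0].strip()
+    score = 1.0 if pred == doc["answer_gt"] else 0.0
+    return {"accuracy": {"score": score, "task": doc["task_name"]}}
+
+def example_aggregate_accuracy(items: list) -> dict:
+    # Average the per-sample scores, overall and per sub-task category.
+    per_task = {}
+    for item in items:
+        per_task.setdefault(item["task"], []).append(item["score"])
+    return {
+        "overall_accuracy": sum(i["score"] for i in items) / len(items),
+        "categorical_accuracy": {t: sum(s) / len(s) for t, s in per_task.items()},
+    }
+
+docs = [{"answer_gt": "A", "task_name": "speech"}, {"answer_gt": "B", "task_name": "music"}]
+preds = ["A", "C"]
+items = [example_process_results(d, [p])["accuracy"] for d, p in zip(docs, preds)]
+print(example_aggregate_accuracy(items))  # overall_accuracy: 0.5
+```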
+
+
+2. **Audio-based Capabilities**
+
+ Our selected benchmarks collectively evaluate various essential audio-based capabilities, as inspired by [AudioBench](https://github.com/AudioLLMs/AudioBench):
+
+ 1. **Audio Captioning:** The ability to accurately transcribe human speech and convert audio content into text.
+ 2. **Speech Understanding:** The capability to comprehend the semantic meaning of human speech, enabling appropriate responses to questions and audio instructions.
+ 3. **Audio Scene Understanding:** The ability to interpret non-human sounds, such as environmental sounds.
+ 4. **Voice Understanding:** The capability to analyze non-speech human vocal information, including emotional states, accents, and speaker characteristics.
+ 5. **Specialized Audio Processing:** The ability to analyze other audio types, such as musical compositions and multilingual content.
+
+### **Meta Information for Audio Datasets**
+
+#### Table 1: Meta information for audio datasets
+
+| **Dataset** | **Year** | **Task Name in lmms-eval** | **Split** | **Task Format** | **Evaluation Metric** | **Number of QAs** | **Feature** |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| **AIR-Bench** | 2024 | air_bench_chat \| air_bench_foundation | chat, foundation | AIF | GPT-4 Eval (chat) \| Accuracy (foundation) | 2k (chat) \| 19k (foundation) | 1. Comprehensive tasks and audio types |
+| **Alpaca Audio** | 2024 | alpaca_audio | test | AIF | GPT-4 Eval | 100 | 1. Synthetic voice |
+| **Clotho-AQA** | 2022 | clotho_aqa | test \| val | AIF | Accuracy | test_v2 (2.06k), test \| val (1.44k \| 1.05k) | 1. Audio question answering<br>2. Single-word answer<br>3. Text-based question |
+| **Common_voice** | 2023 | common_voice_15 | test | ASR | WER (↓) (aligned with Qwen-Audio) | en (16.4k) \| fr (16.1k) \| zh (10.6k) | 1. Real people's voices<br>2. Captioning |
+| **GigaSpeech** | 2021 | gigaspeech | test \| dev | ASR | WER (↓) | dev (6.75k) \| test (25.6k) | 1. Transcription<br>2. Audio books<br>3. YouTube<br>4. Podcasts |
+| **LibriSpeech** | 2015 | librispeech | dev-clean \| dev-other \| test-clean \| test-other | ASR | WER (↓) | dev-clean (~2.48k) \| dev-other (~2.66k) \| test-clean (~2.55k) \| test-other (~2.70k) | 1. Transcription (audio books) |
+| **OpenHermes** | 2024 | openhermes | test | AIF | GPT-4 Eval | 100 | 1. Synthetic voice |
+| **MuchoMusic** | 2024 | muchomusic | test | AIF | Accuracy | 1.19k | 1. Music understanding |
+| **People_speech** | 2021 | people_speech_val | val | ASR | WER (↓) | 18.6k | 1. Real people's voices<br>2. Captioning |
+| **Tedium v3** | 2018 | tedlium_dev_test | val | ASR | WER (↓) | 591 | 1. TED talks<br>2. Real-people ASR<br>3. Captioning |
+| **VocalSound** | 2022 | vocalsound_test | test \| val | AIF | Accuracy | test (3.59k) \| val (1.86k) | 1. Vocal sound recognition<br>2. Non-speech |
+| **WavCaps** | 2024 | wavcaps | test | ASR | GPT-4 Eval | 1.73k | 1. Audio captioning<br>2. ChatGPT-augmented captions |
+
+AIF refers to Audio Instruction Following, and ASR refers to Automatic Speech Recognition.
+
+### Alignment Check for Audio Datasets
+
+#### Table 2: Alignment check for audio datasets
+
+| **Dataset** | **Split** | **Metric** | **Qwen2-Audio-Instruct (lmms-eval)** | **Qwen2-Audio (lmms-eval)** |
+| --- | --- | --- | --- | --- |
+| **AIR-Bench-Chat** | Speech | GPT-Eval | 7.16 | |
+| | Sound | | 6.14 | |
+| | Music | | 6.66 | |
+| | Mixed | | 5.75 | |
+| **AIR-Bench-Foundation** | Speech | Acc | 62.89 | |
+| | Sound | | 55.42 | |
+| | Music | | 56.77 | |
+| **Alpaca** | test | GPT-Eval | 51.8 | |
+| **Clotho_aqa** | test | GPT-Eval | 0.7587 | |
+| **Common_voice** | zh | WER (↓) | 15.78 | 6.7 |
+| | en | | 36.01 | 27.9 |
+| | fr | | 39.88 | 34.8 |
+| **GigaSpeech** | dev | WER (↓) | 19.45 | 14 |
+| | test | | 22.6 | 15.01 |
+| **LibriSpeech** | dev-clean | WER (↓) | 4.24 | 1.66 |
+| | dev-other | | 6.54 | 3.66 |
+| | test-clean | | 3.59 | 1.74 |
+| | test-other | | 7.46 | 3.87 |
+| **MuchoMusic** | test | Acc | 68.32 | 45.07 |
+| **OpenHermes** | test | GPT-Eval | 46.8 | |
+| **People_speech** | val | WER (↓) | 25.86 | 17.1 |
+| **Tedium** | val | WER (↓) | 10.92 | 8.29 |
+| **VocalSound** | test | Acc | 0.936 | 0.81 |
+| | val | | 0.9288 | 0.8 |
+| **WavCaps** | test | GPT-Eval | 1.73 | |
+
+
+The results may be inconsistent with the originally reported numbers because we do not have the original prompts and we must maintain a fair evaluation environment for all models. For the base model, we do not test on the chat benchmarks.
+
+Certain datasets face alignment challenges: datasets that use WER, CIDEr, or BLEU as metrics cannot be aligned precisely because of their rigid output formats. Model responses are also sensitive to the prompt; we investigate this more deeply in the section [Robustness of the model](#robustness-of-the-model).
+
+## Evaluation Analysis and Reflections
+
+During our implementation, we observed several interesting phenomena that may be valuable to discuss. We believe that reflecting on these aspects can help accelerate the development of truly robust audio evaluations.
+
+### Robustness of the model
+
+While trying to align the results, we found that the choice of chat template significantly impacts model performance, even for instruction-tuned models. This finding emerged while analyzing the Qwen2-Audio model. The original Qwen2-Audio repository uses a minimal prompt format: `"<|audio_bos|><|AUDIO|><|audio_eos|>"`.
+
+This basic format is then combined with various question prompts for different evaluation scenarios. However, it is not an instruction-style format, and applying a chat template can change the model's performance significantly. The sketch below contrasts the two prompt styles.
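+
+The sketch below is illustrative only: it contrasts the two prompt styles our Qwen2-Audio wrapper supports (the `simple_prompt` path versus the chat-template path). Running it downloads the Qwen2-Audio processor from Hugging Face, and the question string is just an example.
+
+```python
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
+question = "Please recognize the speech and only output the recognized content."
+
+# 1) Minimal prompt, as used in the original Qwen2-Audio repository (simple_prompt=True).
+simple_text = "<|audio_bos|><|AUDIO|><|audio_eos|>" + question
+
+# 2) Chat-template prompt (simple_prompt=False): the same question wrapped in an
+#    instruction-style conversation. The audio placeholder only drives the template;
+#    the actual waveform is passed to the processor separately.
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "audio", "audio_url": "placeholder.wav"},
+            {"type": "text", "text": question},
+        ],
+    }
+]
+chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+
+print(simple_text)
+print(chat_text)  # wraps the prompt in <|im_start|>/<|im_end|> turns and appends an assistant prefix
+```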
+
+#### Table 3: Impact of the Chat Template on Qwen2-Audio-7B-Instruct's Performance
+
+| **Dataset** | **Split** | **Metric** | **Chat Template (Off)** | **Chat Template (On)** |
+| --- | --- | --- | --- | --- |
+| **LibriSpeech** | dev-clean | WER (↓) | 2.65 | 4.24 |
+| | dev-other | | 5.36 | 6.54 |
+| | test-clean | | 2.91 | 3.59 |
+| | test-other | | 5.14 | 7.46 |
+| **People_speech** | val | WER (↓) | 21.92 | 25.86 |
+| **Tedium** | dev_test | WER (↓) | 9.56 | 10.92 |
+
+More specifically, as shown in the table above, the influence of the chat template is substantial. We believe this reflects the limited robustness of current audio models: they may not yet be stable enough when coping with different text inputs. It also leads us to another question: are current metrics actually good at evaluating a model's performance?
+
+### Rethinking the evaluation metrics
+
+Traditional fixed-format metrics like WER, CIDEr, and BLEU face several limitations in audio model evaluation:
+
+1. **Format Rigidity:** Fixed metrics struggle to properly assess responses that are semantically correct but differ in format from the reference answers.
+2. **Prompt Sensitivity:** These metrics are highly sensitive to variations in the input prompt, leading to inconsistent evaluation results.
+
+Due to these limitations, the scores reported in `lmms-eval` might slightly differ from those reported in original papers, highlighting the challenge of maintaining consistent evaluation standards across different frameworks.
+
+Looking ahead, model-based evaluators such as GPT-4 could offer a more flexible and robust evaluation approach. Such evaluators can better understand semantic meaning, handle diverse response formats, and provide more consistent scoring across different implementations. This shift from rigid metrics to intelligent evaluation systems may better capture the true capabilities of audio processing models.
+
+## Additional Experiments
+
+### Batch Size
+
+We perform an exploratory batch inference experiment on Qwen2-Audio with the following results:
+
+#### Table 4: Impact of batch size
+
+| **Dataset** | **Split** | **Metric** | **Qwen2-Audio (BS=4)** | **Qwen2-Audio (BS=1)** |
+| --- | --- | --- | --- | --- |
+| **LibriSpeech** | dev-clean | WER (↓) | 1.66 | 1.66 |
+| | dev-other | | 4.4 | 3.66 |
+| | test-clean | | 1.75 | 1.74 |
+| | test-other | | 4.06 | 3.87 |
+| **Total Time** | | | 10 min 50 s | 5 min 23 s |
+
+As shown above, although batch inference (BS=4) can significantly reduce inference time, it may lead to evaluation inconsistencies compared to single-sample processing (BS=1). This is a known issue in the `transformers` library that currently lacks a solution. A generic sketch of one common source of such discrepancies follows.
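+
+One common source of such discrepancies, though not necessarily the only factor here, is padding in batched generation: when prompts of different lengths are batched, the shorter ones are padded, and the padded positions can subtly perturb the numerics (especially in fp16/bf16) and occasionally change a greedy decoding step. The generic sketch below uses GPT-2 as a small stand-in model to illustrate the two settings; it is not the Qwen2-Audio evaluation code.
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Small text-only stand-in model, used only to illustrate batched vs. single-sample decoding.
+tok = AutoTokenizer.from_pretrained("gpt2")
+tok.pad_token = tok.eos_token
+tok.padding_side = "left"  # left padding, so generation continues from the real tokens
+model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
+
+prompts = ["The quick brown fox", "Hello"]
+
+with torch.no_grad():
+    # BS=1: each prompt is processed alone, so no padding is involved.
+    single = [model.generate(**tok(p, return_tensors="pt"), max_new_tokens=8, do_sample=False, pad_token_id=tok.eos_token_id) for p in prompts]
+
+    # BS=2: the shorter prompt is padded; compare the decoded continuations with the BS=1 run.
+    batch = tok(prompts, return_tensors="pt", padding=True)
+    batched = model.generate(**batch, max_new_tokens=8, do_sample=False, pad_token_id=tok.eos_token_id)
+
+print(tok.batch_decode(batched, skip_special_tokens=True))
+print([tok.decode(s[0], skip_special_tokens=True) for s in single])
+```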
+
+### More Details and Feature Updates with `v0.3.0`
+
+1. **Supported Audio Tasks**
+ 1. [AirBench](https://github.com/OFA-Sys/AIR-Bench)
+ 2. [Alpaca Audio](https://tango2-web.github.io/)
+ 3. [Clotho-AQA](https://github.com/partha2409/AquaNet)
+ 4. [Common_voice_15](https://github.com/common-voice/common-voice)
+ 5. [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
+ 6. [LibriSpeech](https://www.openslr.org/12)
+ 7. [OpenHermes](https://huggingface.co/datasets/AudioLLMs/openhermes_instruction_test)
+ 8. [MuchoMusic](https://github.com/mulab-mir/muchomusic)
+ 9. [Peoples_speech](https://mlcommons.org/datasets/peoples-speech/)
+ 10. [Tedium v3](https://www.openslr.org/51/)
+ 11. [VocalSound](https://github.com/YuanGongND/vocalsound)
+ 12. [WavCaps](https://github.com/XinhaoMei/WavCaps)
+2. **Supported Audio Models**
+
+ 1. [Qwen2-Audio](https://github.com/QwenLM/Qwen2-Audio)
+ 2. [Gemini_Audio](https://arxiv.org/abs/2312.11805)
+3. **Supporting Multi-Round Evaluation**
+ 1. [Feat][Task] Add multi-round evaluation in llava-onevision; Add MMSearch Benchmark by [@CaraJ7](https://github.com/CaraJ7) in [#277](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/277)
+4. **Regression Test**
+ 1. [Feat] add regression test and change saving logic related to `output_path` by [@Luodian](https://github.com/Luodian) in [#259](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/259)
+5. **Speed-up by loading only required tasks and models**
+ 1. [feat] remove registeration logic and adding language evaluation tasks. by [@Luodian](https://github.com/Luodian) in [#218](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/218)
+6. **LMMs-Eval Analysis Tool**
+ 1. Lite/Core-set Selection by Kaichen Zhang
+
+ https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/main/tools/lite
+
+ 2. LiveBench by Fanyi Pu
+
+ https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/main/tools/live_bench
+
+7. **SGLang Evaluation**
+ 1. [Feat] SGLang SRT commands in one go, async input for openai server by [@kcz358](https://github.com/kcz358) in [#212](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/212)
+ 2. [Fix] Fix async append result in different order issue by [@kcz358](https://github.com/kcz358) in [#244](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/244)
+
+## Contributors
+
+> Listed in order of contribution significance.
+>
+
+**Core Contributors**
+
+Pengyun Wang, Cong Pham Ba, Yingluo Li, Fanyi Pu
+
+**Release Managers**
+
+Kairui Hu, Kaichen Zhang, Bo Li
diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py
index 0646e21f..c6460e9d 100755
--- a/lmms_eval/evaluator.py
+++ b/lmms_eval/evaluator.py
@@ -492,7 +492,15 @@ def evaluate(
metrics = task.process_results(doc, [req.filtered_resps[filter_key] for req in requests])
if log_samples:
target = task.doc_to_target(doc)
- saved_doc = {key: value for key, value in doc.items() if "image" not in key}
+ saved_doc = {}
+ for key, value in doc.items():
+ # Skip fields whose key contains "image"
+ if "image" not in key:
+ # Also skip audio fields (dicts that carry a raw waveform array)
+ if isinstance(value, dict) and "array" in value:
+ continue
+ else:
+ saved_doc[key] = value
filtered_arguments = []
for req in requests:
# check if req.args is a list of tuples, and each item in the list is a serializable object
diff --git a/lmms_eval/evaluator_utils.py b/lmms_eval/evaluator_utils.py
index 48b5c978..58537d0f 100644
--- a/lmms_eval/evaluator_utils.py
+++ b/lmms_eval/evaluator_utils.py
@@ -346,7 +346,7 @@ def consolidate_group_results(
task_root=None,
show_group_table=False,
task_aggregation_list=None,
-) -> Tuple[dict, dict, bool, Union[None,]]:
+) -> Tuple[dict, dict, bool, Union[None, dict]]:
"""
(Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py
index 93103d28..ddf281f6 100644
--- a/lmms_eval/models/__init__.py
+++ b/lmms_eval/models/__init__.py
@@ -40,6 +40,7 @@
"phi3v": "Phi3v",
"qwen_vl": "Qwen_VL",
"qwen2_vl": "Qwen2_VL",
+ "qwen2_audio": "Qwen2_Audio",
"qwen_vl_api": "Qwen_VL_API",
"reka": "Reka",
"srt_api": "SRT_API",
diff --git a/lmms_eval/models/gemini_api.py b/lmms_eval/models/gemini_api.py
index 65718bc5..69f520c1 100644
--- a/lmms_eval/models/gemini_api.py
+++ b/lmms_eval/models/gemini_api.py
@@ -4,6 +4,7 @@
import time
from typing import List, Tuple
+import datasets
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from PIL import Image
@@ -25,16 +26,22 @@
eval_logger.error(f"Error importing generativeai: {str(e)}")
genai = None
+try:
+ import soundfile as sf
+except Exception as e:
+ eval_logger.warning(f"Error importing soundfile, audio generation will not work: {str(e)}")
+
@register_model("gemini_api")
class GeminiAPI(lmms):
def __init__(
self,
model_version: str = "gemini-1.5-pro",
- modality: str = "image",
+ # modality: str = "image",
timeout: int = 120,
continual_mode: bool = False,
- response_persistent_folder: str = None, # We will cache the Gemini API response in this path and use it for future requests
+ response_persistent_folder: str = "./logs/gemini_persistent_folder",
+ # We will cache the Gemini API response in this path and use it for future requests
**kwargs,
) -> None:
super().__init__()
@@ -42,12 +49,13 @@ def __init__(
self.timeout = timeout
self.model = genai.GenerativeModel(model_version)
self.continual_mode = continual_mode
- if self.continual_mode and response_persistent_folder is None:
- raise ValueError("Continual mode requires a persistent path for the response. We will cache the Gemini API response in this path and use it for future requests. Please provide a valid path.")
- self.response_persistent_folder = response_persistent_folder
- if not os.path.exists(self.response_persistent_folder):
- os.makedirs(self.response_persistent_folder)
- self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
+ # if self.continual_mode and response_persistent_folder is None:
+ # raise ValueError("Continual mode requires a persistent path for the response. We will cache the Gemini API response in this path and use it for future requests. Please provide a valid path.")
+ if self.continual_mode:
+ self.response_persistent_folder = response_persistent_folder
+ if not os.path.exists(self.response_persistent_folder):
+ os.makedirs(self.response_persistent_folder)
+ self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
if os.path.exists(self.response_persistent_file):
with open(self.response_persistent_file, "r") as f:
@@ -73,7 +81,7 @@ def __init__(
self.device = self.accelerator.device
- self.modality = modality
+ # self.modality = modality
self.video_pool = []
@@ -107,9 +115,17 @@ def encode_video(self, video_path):
self.video_pool.append(uploaded_obj)
return uploaded_obj
- def convert_video(self, images):
+ def encode_audio(self, audio):
+ audio_io = io.BytesIO()
+ sf.write(audio_io, audio["array"], audio["sampling_rate"], format="WAV")
+ return genai.upload_file(audio_io, mime_type="audio/wav")
+
+ def convert_modality(self, images):
for idx, img in enumerate(images):
- if self.modality == "video" and isinstance(img, str):
+ if isinstance(img, dict) and "sampling_rate" in img: # audio
+ audio = self.encode_audio(img)
+ images[idx] = audio
+ elif isinstance(img, str): # video
try:
images[idx] = self.encode_video(img)
except Exception as e:
@@ -145,7 +161,7 @@ def get_uuid(task, split, doc_id):
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
- visuals = self.convert_video(visuals)
+ visuals = self.convert_modality(visuals)
message = [contexts] + visuals
diff --git a/lmms_eval/models/model_utils/audio_processing.py b/lmms_eval/models/model_utils/audio_processing.py
new file mode 100644
index 00000000..2c27970b
--- /dev/null
+++ b/lmms_eval/models/model_utils/audio_processing.py
@@ -0,0 +1,7 @@
+import numpy as np
+from librosa import resample
+
+
+def downsample_audio(audio_array: np.ndarray, original_sr: int, target_sr: int) -> np.ndarray:
+ audio_resample_array = resample(audio_array, orig_sr=original_sr, target_sr=target_sr)
+ return audio_resample_array
diff --git a/lmms_eval/models/qwen2_audio.py b/lmms_eval/models/qwen2_audio.py
new file mode 100755
index 00000000..01f0cf22
--- /dev/null
+++ b/lmms_eval/models/qwen2_audio.py
@@ -0,0 +1,284 @@
+import base64
+from io import BytesIO
+from typing import List, Optional, Tuple, Union
+
+import decord
+import torch
+from accelerate import Accelerator, DistributedType
+from loguru import logger as eval_logger
+from tqdm import tqdm
+from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
+
+from lmms_eval import utils
+from lmms_eval.api.instance import Instance
+from lmms_eval.api.model import lmms
+from lmms_eval.api.registry import register_model
+from lmms_eval.models.model_utils.audio_processing import downsample_audio
+
+
+@register_model("qwen2_audio")
+class Qwen2_Audio(lmms):
+ """
+ Qwen2_Audio Model
+ "https://github.com/QwenLM/Qwen2-Audio"
+ """
+
+ def __init__(
+ self,
+ pretrained: str = "Qwen/Qwen2-Audio-7B", # Qwen/Qwen2-Audio-7B-Instruct
+ device: Optional[str] = "cuda",
+ device_map: Optional[str] = "cuda",
+ batch_size: Optional[Union[int, str]] = 1,
+ use_cache=True,
+ add_generation_prompt: bool = True,
+ add_system_prompt: bool = True,
+ simple_prompt: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ # Do not use kwargs for now
+ assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+
+ accelerator = Accelerator()
+ self.add_generation_prompt = add_generation_prompt
+ self.add_system_prompt = add_system_prompt
+ # If using simple prompt, only add "<|audio_bos|><|AUDIO|><|audio_eos|>"
+ # and then prompt to align with original Qwen2 Audio
+ self.simple_prompt = simple_prompt
+ if accelerator.num_processes > 1:
+ self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+ self.device_map = f"cuda:{accelerator.local_process_index}"
+ elif accelerator.num_processes == 1 and device_map == "auto":
+ self._device = torch.device(device)
+ self.device_map = device_map
+ else:
+ self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+ self.device_map = f"cuda:{accelerator.local_process_index}"
+
+ self._model = Qwen2AudioForConditionalGeneration.from_pretrained(
+ pretrained,
+ torch_dtype="auto",
+ device_map=device_map,
+ ).eval()
+
+ self.processor = AutoProcessor.from_pretrained(pretrained)
+ self.processor.tokenizer.padding_side = "left"
+ self._tokenizer = self.processor.tokenizer
+
+ if not self.add_system_prompt:
+ # Overwrite chat template to exclude system prompt
+ self.processor.chat_template = (
+ "{% set audio_count = namespace(value=0) %}"
+ "{% for message in messages %}"
+ "<|im_start|>{{ message['role'] }}\n"
+ "{% if message['content'] is string %}"
+ "{{ message['content'] }}<|im_end|>\n"
+ "{% else %}"
+ "{% for content in message['content'] %}"
+ "{% if 'audio' in content or 'audio_url' in content %}"
+ "{% set audio_count.value = audio_count.value + 1 %}"
+ "Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
+ "{% elif 'text' in content %}"
+ "{{ content['text'] }}"
+ "{% endif %}"
+ "{% endfor %}"
+ "<|im_end|>\n"
+ "{% endif %}"
+ "{% endfor %}"
+ "{% if add_generation_prompt %}"
+ "<|im_start|>assistant\n"
+ "{% endif %}"
+ )
+
+ self._config = self.model.config
+ self.batch_size_per_gpu = int(batch_size)
+ self.use_cache = use_cache
+
+ if accelerator.num_processes > 1:
+ assert accelerator.distributed_type in [
+ DistributedType.FSDP,
+ DistributedType.MULTI_GPU,
+ ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+ if accelerator.distributed_type == DistributedType.FSDP:
+ self._model = accelerator.prepare(self.model)
+ else:
+ self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+ self.accelerator = accelerator
+ if self.accelerator.is_local_main_process:
+ eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
+ self._rank = self.accelerator.local_process_index
+ self._world_size = self.accelerator.num_processes
+ else:
+ self.model.to(self._device)
+ self._rank = 0
+ self._world_size = 1
+
+ @property
+ def config(self):
+ # return the associated transformers.AutoConfig for the given pretrained model.
+ return self._config
+
+ @property
+ def tokenizer(self):
+ return self._tokenizer
+
+ @property
+ def model(self):
+ # returns the model, unwrapping it if using Accelerate
+ if hasattr(self, "accelerator"):
+ return self.accelerator.unwrap_model(self._model)
+ else:
+ return self._model
+
+ @property
+ def eot_token_id(self):
+ return self.tokenizer.eos_token_id
+
+ @property
+ def max_length(self):
+ return self._max_length
+
+ @property
+ def batch_size(self):
+ return self.batch_size_per_gpu
+
+ @property
+ def device(self):
+ return self._device
+
+ @property
+ def rank(self):
+ return self._rank
+
+ @property
+ def world_size(self):
+ return self._world_size
+
+ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+ raise NotImplementedError("Loglikelihood is not implemented for Qwen2_Audio")
+
+ def flatten(self, input):
+ new_list = []
+ for i in input:
+ for j in i:
+ new_list.append(j)
+ return new_list
+
+ def generate_until(self, requests: List[Instance]) -> List[str]:
+ res = []
+
+ def _collate(x):
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
+ # - time estimates will always be over not underestimates, which is more useful for planning
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
+ # automatic adaptive batches much much easier to implement
+ # - any OOMs will happen right away rather than near the end
+ toks = self.tokenizer.encode(x[0])
+ return -len(toks), x[0]
+
+ pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
+ # we group requests by their generation_kwargs,
+ # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+ # in the same batch.
+ re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
+ chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+ for chunk in chunks:
+ contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
+ task = task[0]
+ split = split[0]
+ batched_audios = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
+ flattened_audios = self.flatten(batched_audios)
+
+ # we assume all gen kwargs in the batch are the same
+ # this is safe to assume because the `grouper` object ensures it.
+ gen_kwargs = all_gen_kwargs[0]
+
+ # Set default values for until and max_new_tokens
+ until = [self.tokenizer.decode(self.eot_token_id)]
+
+ # Update values from gen_kwargs if present
+ if "until" in gen_kwargs:
+ until = gen_kwargs.pop("until")
+ if isinstance(until, str):
+ until = [until]
+ elif not isinstance(until, list):
+ raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
+
+ # contexts = "<|audio_bos|><|AUDIO|><|audio_eos|>" + contexts
+
+ if isinstance(contexts, tuple):
+ contexts = list(contexts)
+
+ if not self.simple_prompt:
+ conversations = []
+ for idx, context in enumerate(contexts):
+ conv = [{"role": "user", "content": []}]
+ for _ in batched_audios[idx]:
+ # This placeholder is just used to make the chat template work
+ # We already have the sampled audio array
+ conv[0]["content"].append({"type": "audio", "audio_url": "placeholder.wav"})
+ conv[0]["content"].append({"type": "text", "text": context})
+ conversations.append(conv)
+
+ text = [self.processor.apply_chat_template(conversation, add_generation_prompt=self.add_generation_prompt, tokenize=False) for conversation in conversations]
+ else:
+ text = ["<|audio_bos|><|AUDIO|><|audio_eos|>" + context for context in contexts]
+ audios = [downsample_audio(audio["array"], audio["sampling_rate"], self.processor.feature_extractor.sampling_rate) for audio in flattened_audios]
+
+ inputs = self.processor(text=text, audios=audios, return_tensors="pt", padding=True, sampling_rate=self.processor.feature_extractor.sampling_rate)
+
+ if self.device_map == "auto":
+ inputs = inputs.to("cuda")
+ else:
+ inputs = inputs.to(self.device)
+
+ if "max_new_tokens" not in gen_kwargs:
+ gen_kwargs["max_new_tokens"] = 256
+ if "temperature" not in gen_kwargs:
+ gen_kwargs["temperature"] = 0
+ if "top_p" not in gen_kwargs:
+ gen_kwargs["top_p"] = None
+ if "num_beams" not in gen_kwargs:
+ gen_kwargs["num_beams"] = 1
+
+ try:
+ cont = self.model.generate(
+ **inputs,
+ do_sample=True if gen_kwargs["temperature"] > 0 else False,
+ temperature=gen_kwargs["temperature"],
+ top_p=gen_kwargs["top_p"],
+ num_beams=gen_kwargs["num_beams"],
+ max_new_tokens=gen_kwargs["max_new_tokens"],
+ min_new_tokens=1,
+ use_cache=self.use_cache,
+ )
+
+ # cont = self.model.generate(**inputs, max_new_tokens=256, min_new_tokens=1, do_sample=False)
+
+ generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
+ # generated_ids_trimmed = cont[:, inputs.input_ids.size(1):]
+ answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ for i, ans in enumerate(answers):
+ for term in until:
+ if len(term) > 0:
+ ans = ans.split(term)[0]
+ answers[i] = ans
+
+ except Exception as e:
+ eval_logger.debug(f"Error while generating: {e}. It is possibly due to blank audio in {contexts}")
+ answers = [""] * len(contexts)
+
+ for ans, context in zip(answers, contexts):
+ res.append(ans)
+ self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
+ pbar.update(1)
+
+ # reorder this group of results back to original unsorted form
+ res = re_ords.get_original(res)
+
+ pbar.close()
+ return res
+
+ def generate_until_multi_round(self, requests) -> List[str]:
+ raise NotImplementedError("TODO: Implement multi-round generation")
diff --git a/lmms_eval/tasks/air_bench/_default_template_yaml b/lmms_eval/tasks/air_bench/_default_template_yaml
new file mode 100644
index 00000000..fccfc5a4
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/_default_template_yaml
@@ -0,0 +1,7 @@
+dataset_path: lmms-lab/AIR_Bench
+dataset_kwargs:
+ token: True
+
+metadata:
+ gpt_eval_model_name: gpt-4o
+ version: 0.0
diff --git a/lmms_eval/tasks/air_bench/air_bench_chat.yaml b/lmms_eval/tasks/air_bench/air_bench_chat.yaml
new file mode 100644
index 00000000..7bc686f3
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_chat.yaml
@@ -0,0 +1,6 @@
+group: air_bench_chat
+task:
+ - air_bench_chat_sound
+ - air_bench_chat_music
+ - air_bench_chat_speech
+ - air_bench_chat_mixed
diff --git a/lmms_eval/tasks/air_bench/air_bench_chat_mixed.yaml b/lmms_eval/tasks/air_bench/air_bench_chat_mixed.yaml
new file mode 100644
index 00000000..1d831978
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_chat_mixed.yaml
@@ -0,0 +1,25 @@
+task: "air_bench_chat_mixed"
+dataset_name: "Chat"
+test_split: mixed
+doc_to_target: "answer_gt"
+doc_to_visual: !function utils.air_bench_doc_to_audio
+doc_to_text: !function utils.air_bench_doc_to_text_chat
+
+generation_kwargs:
+ max_new_tokens: 1024
+ temperature: 0.2
+ top_p: 1.0
+ num_beams: 1
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "Give a detail answer to the question in English."
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.air_bench_aggregate_results_chat
+ higher_is_better: true
+
+process_results: !function utils.air_bench_process_results_chat
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/air_bench/air_bench_chat_music.yaml b/lmms_eval/tasks/air_bench/air_bench_chat_music.yaml
new file mode 100644
index 00000000..e4fd4a10
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_chat_music.yaml
@@ -0,0 +1,25 @@
+task: "air_bench_chat_music"
+dataset_name: "Chat"
+test_split: music
+doc_to_target: "answer_gt"
+doc_to_visual: !function utils.air_bench_doc_to_audio
+doc_to_text: !function utils.air_bench_doc_to_text_chat
+
+generation_kwargs:
+ max_new_tokens: 1024
+ temperature: 0.2
+ top_p: 1.0
+ num_beams: 1
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "Give a detail answer to the question in English."
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.air_bench_aggregate_results_chat
+ higher_is_better: true
+
+process_results: !function utils.air_bench_process_results_chat
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/air_bench/air_bench_chat_sound.yaml b/lmms_eval/tasks/air_bench/air_bench_chat_sound.yaml
new file mode 100644
index 00000000..fb281452
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_chat_sound.yaml
@@ -0,0 +1,25 @@
+task: "air_bench_chat_sound"
+dataset_name: "Chat"
+test_split: sound
+doc_to_target: "answer_gt"
+doc_to_visual: !function utils.air_bench_doc_to_audio
+doc_to_text: !function utils.air_bench_doc_to_text_chat
+
+generation_kwargs:
+ max_new_tokens: 1024
+ temperature: 0.2
+ top_p: 1.0
+ num_beams: 1
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "Give a detail answer to the question in English."
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.air_bench_aggregate_results_chat
+ higher_is_better: true
+
+process_results: !function utils.air_bench_process_results_chat
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/air_bench/air_bench_chat_speech.yaml b/lmms_eval/tasks/air_bench/air_bench_chat_speech.yaml
new file mode 100644
index 00000000..9a488ba4
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_chat_speech.yaml
@@ -0,0 +1,25 @@
+task: "air_bench_chat_speech"
+dataset_name: "Chat"
+test_split: speech
+doc_to_target: "answer_gt"
+doc_to_visual: !function utils.air_bench_doc_to_audio
+doc_to_text: !function utils.air_bench_doc_to_text_chat
+
+generation_kwargs:
+ max_new_tokens: 1024
+ temperature: 0.2
+ top_p: 1.0
+ num_beams: 1
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "Give a detail answer to the question in English."
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.air_bench_aggregate_results_chat
+ higher_is_better: true
+
+process_results: !function utils.air_bench_process_results_chat
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/air_bench/air_bench_foundation.yaml b/lmms_eval/tasks/air_bench/air_bench_foundation.yaml
new file mode 100644
index 00000000..bdcbe038
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_foundation.yaml
@@ -0,0 +1,5 @@
+group: air_bench_foundation
+task:
+ - air_bench_foundation_sound
+ - air_bench_foundation_music
+ - air_bench_foundation_speech
diff --git a/lmms_eval/tasks/air_bench/air_bench_foundation_music.yaml b/lmms_eval/tasks/air_bench/air_bench_foundation_music.yaml
new file mode 100644
index 00000000..566ad7a2
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_foundation_music.yaml
@@ -0,0 +1,26 @@
+task: "air_bench_foundation_music"
+dataset_name: "Foundation"
+test_split: music
+doc_to_target: !function utils.air_bench_doc_to_target_foundation
+doc_to_visual: !function utils.air_bench_doc_to_audio
+doc_to_text: !function utils.air_bench_doc_to_text_foundation
+doc_to_choice: !function utils.air_bench_doc_to_choice_foundation
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ do_sample: False
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nAnswer with the option's letter from the given choices directly."
+metric_list:
+ - metric: accuracy
+ aggregation: !function utils.air_bench_aggregate_results_foundation
+ higher_is_better: true
+ - metric: submission
+ aggregation: !function utils.air_bench_aggregate_results_for_submission
+ higher_is_better: true
+
+process_results: !function utils.air_bench_process_results_foundation
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/air_bench/air_bench_foundation_sound.yaml b/lmms_eval/tasks/air_bench/air_bench_foundation_sound.yaml
new file mode 100644
index 00000000..f92160e9
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_foundation_sound.yaml
@@ -0,0 +1,26 @@
+task: "air_bench_foundation_sound"
+dataset_name: "Foundation"
+test_split: sound
+doc_to_target: !function utils.air_bench_doc_to_target_foundation
+doc_to_visual: !function utils.air_bench_doc_to_audio
+doc_to_text: !function utils.air_bench_doc_to_text_foundation
+doc_to_choice: !function utils.air_bench_doc_to_choice_foundation
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ do_sample: False
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nAnswer with the option's letter from the given choices directly."
+metric_list:
+ - metric: accuracy
+ aggregation: !function utils.air_bench_aggregate_results_foundation
+ higher_is_better: true
+ - metric: submission
+ aggregation: !function utils.air_bench_aggregate_results_for_submission
+ higher_is_better: true
+
+process_results: !function utils.air_bench_process_results_foundation
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/air_bench/air_bench_foundation_speech.yaml b/lmms_eval/tasks/air_bench/air_bench_foundation_speech.yaml
new file mode 100644
index 00000000..9e1dba8f
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/air_bench_foundation_speech.yaml
@@ -0,0 +1,26 @@
+task: "air_bench_foundation_speech"
+dataset_name: "Foundation"
+test_split: speech
+doc_to_target: !function utils.air_bench_doc_to_target_foundation
+doc_to_visual: !function utils.air_bench_doc_to_audio
+doc_to_text: !function utils.air_bench_doc_to_text_foundation
+doc_to_choice: !function utils.air_bench_doc_to_choice_foundation
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ do_sample: False
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nAnswer with the option's letter from the given choices directly."
+metric_list:
+ - metric: accuracy
+ aggregation: !function utils.air_bench_aggregate_results_foundation
+ higher_is_better: true
+ - metric: submission
+ aggregation: !function utils.air_bench_aggregate_results_for_submission
+ higher_is_better: true
+
+process_results: !function utils.air_bench_process_results_foundation
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/air_bench/utils.py b/lmms_eval/tasks/air_bench/utils.py
new file mode 100644
index 00000000..32ea79c5
--- /dev/null
+++ b/lmms_eval/tasks/air_bench/utils.py
@@ -0,0 +1,267 @@
+import datetime
+import json
+import os
+import random
+import re
+import sys
+import time
+from pathlib import Path
+
+import numpy as np
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
+
+
+def air_bench_doc_to_audio(doc):
+ return [doc["audio"]]
+
+
+def air_bench_doc_to_text_chat(doc, lmms_eval_specific_kwargs):
+ question = doc["question"]
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}{question}{post_prompt}"
+
+
+with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
+ raw_data = f.readlines()
+ safe_data = []
+ for i, line in enumerate(raw_data):
+ # remove function definition since yaml load cannot handle it
+ if "!function" not in line:
+ safe_data.append(line)
+
+ config = yaml.safe_load("".join(safe_data))
+
+# specify api type and key in .env
+GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
+API_TYPE = os.getenv("API_TYPE", "azure")
+
+if API_TYPE == "openai":
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "Authorization": f"Bearer {API_KEY}",
+ "Content-Type": "application/json",
+ }
+elif API_TYPE == "azure":
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "api-key": API_KEY,
+ "Content-Type": "application/json",
+ }
+
+# prompt taken from the AIR-Bench repo
+eval_prompt = (
+ "You are a helpful and precise assistant for checking the quality of the answer.\n"
+ "[Detailed Audio Description]\nXAudioX\n[Question]\nXQuestionX\n"
+ "[The Start of Assistant 1s Answer]\nXAssistant1X\n[The End of Assistant 1s Answer]\n"
+ "[The Start of Assistant 2s Answer]\nXAssistant2X\n[The End of Assistant 2s Answer]\n[System]\n"
+ "We would like to request your feedback on the performance of two AI assistants in response to the user question "
+ "and audio description displayed above. AI assistants are provided with detailed audio descriptions and questions.\n"
+ "Please rate the helpfulness, relevance, accuracy, and comprehensiveness of their responses. "
+ "Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance. "
+ "Please output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. "
+ "The two scores are separated by a space. Please only output the 2 required number and no text."
+)
+
+retries = 3
+NUM_SECONDS_TO_SLEEP = 5
+
+
+def get_eval(max_tokens: int, content: str, retries: int = retries):
+ global headers
+
+ messages = [
+ {"role": "user", "content": content},
+ ]
+
+ payload = {
+ "model": GPT_EVAL_MODEL_NAME,
+ "messages": messages,
+ "temperature": 0,
+ "max_tokens": max_tokens,
+ }
+
+ for attempt in range(retries):
+ try:
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+ response.raise_for_status()
+ response_data = response.json()
+
+ content = response_data["choices"][0]["message"]["content"].strip()
+ if content != "":
+ return content, response_data["model"]
+ break # If successful, break out of the loop
+
+ except Exception as e:
+ eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
+ if attempt < retries - 1:  # If we have retries left, sleep and then continue to the next attempt
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+ else: # If this was the last attempt, log and return empty
+ eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
+ return "", ""
+ return "", ""
+
+
+def air_bench_process_results_chat(doc, result):
+ path = doc["path"]
+ question = doc["question"]
+ answer_gt = doc["answer_gt"]
+ task_name = doc["task_name"]
+ dataset_name = doc["dataset_name"]
+ response = result[0]
+
+ if response is None:
+ exit(1)
+
+ if doc["meta_info"] is None:
+ eval_logger.error("Document lacks meta_info")
+ exit(1)
+ else:
+ meta_info = doc["meta_info"]
+
+ # Get the evaluation score twice: once with our model as assistant 2 and once with our model as assistant 1, to prevent position bias
+ content = eval_prompt.replace("XAudioX", meta_info).replace("XQuestionX", question).replace("XAssistant1X", answer_gt).replace("XAssistant2X", response)
+ eval_answer, model_name = get_eval(max_tokens=1024, content=content)
+ content = eval_prompt.replace("XAudioX", meta_info).replace("XQuestionX", question).replace("XAssistant1X", response).replace("XAssistant2X", answer_gt)
+ eval_answer2, model_name2 = get_eval(max_tokens=1024, content=content)
+
+ return {
+ "gpt_eval": {"eval_answer": [eval_answer, eval_answer2], "model_name": model_name},
+ }
+
+
+def air_bench_aggregate_results_chat(results):
+ score = 0
+ for result in results:
+ eval_answer = result["eval_answer"]
+ eval_answer = [eval_answer[i].strip().replace(".", "") for i in range(len(eval_answer))]
+ pattern = r"\b(?:[1-9]|10)\b"
+
+ # Find all matches
+ try:
+ matches1 = re.findall(pattern, eval_answer[0])
+ matches2 = re.findall(pattern, eval_answer[1])
+ # Get the first two occurrences
+ eval_score = float(matches1[1]) + float(matches2[0])
+ score += eval_score
+ except Exception as e:
+ eval_logger.error(f"Error parsing eval_score: {e}")
+ continue
+
+ return score / (2 * len(results))
+
+
+# Functions for Foundation tasks
+
+
+def air_bench_doc_to_text_foundation(doc, lmms_eval_specific_kwargs):
+ question = doc["question"]
+ answers = [doc["choice_a"], doc["choice_b"], doc.get("choice_c", None), doc.get("choice_d", None)]
+ question = f"{question}\nA. {answers[0]}\nB. {answers[1]}\nC. {answers[2]}\nD. {answers[3]}\n"
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}{question}{post_prompt}"
+
+
+def air_bench_doc_to_target_foundation(doc):
+ return get_gt(doc)
+
+
+def air_bench_doc_to_choice_foundation(doc):
+ choices = []
+ for option in ["choice_a", "choice_b", "choice_c", "choice_d"]:
+ if doc.get(option) and doc[option].strip():
+ choices.append(option[-1].upper())
+ return choices
+
+
+def air_bench_process_results_foundation(doc, result):
+ response = result[0].strip()
+ all_choices = [choice[-1].upper() for choice in ["choice_a", "choice_b", "choice_c", "choice_d"] if doc.get(choice)]
+ pred = parse_multi_choice_response(response, all_choices)  # Adapted from MMMU
+ gt_ans = get_gt(doc)
+ score = 1.0 if pred == gt_ans else 0.0
+ submission_dict = {doc.get("uniq_id", "unknown"): pred}
+ return {"accuracy": {"score": score, "task": doc["task_name"]}, "submission": submission_dict}
+
+
+def air_bench_aggregate_results_for_submission(results, args):
+ path = generate_submission_file("air_bench_test_submission.json", args)
+ with open(path, "w") as f:
+ json.dump(results, f)
+ eval_logger.info(f"Results saved to {path}.")
+
+
+def air_bench_aggregate_results_foundation(results):
+ score = 0
+ categorical_correct = {}
+ categorical_total = {}
+ for result in results:
+ score += result["score"]
+ if result["task"] not in categorical_correct.keys():
+ categorical_correct[result["task"]] = 0
+ categorical_total[result["task"]] = 0
+ categorical_correct[result["task"]] += result["score"]
+ categorical_total[result["task"]] += 1
+
+ return {"overall_accuracy": score / len(results), "categorical_accuracy": {task: categorical_correct[task] / categorical_total[task] for task in categorical_correct.keys()}}
+
+
+def parse_multi_choice_response(response, all_choices):
+ """
+ Parse the prediction from the generated response.
+ Return the predicted choice letter e.g., A, B, C, D.
+ """
+ # Clean response of unwanted characters
+ for char in [",", ".", "!", "?", ";", ":", "'"]:
+ response = response.strip(char)
+ response = " " + response + " " # Add space to avoid partial match
+
+ candidates = []
+ # Look for choices with parentheses, e.g., (A)
+ for choice in all_choices:
+ if f"({choice})" in response:
+ candidates.append(choice)
+
+ # Look for simple choices, e.g., A, B, C
+ if len(candidates) == 0:
+ for choice in all_choices:
+ if f" {choice} " in response:
+ candidates.append(choice)
+
+ # Look for choices with periods, e.g., A., B., C.
+ if len(candidates) == 0:
+ for choice in all_choices:
+ if f"{choice}." in response:
+ candidates.append(choice)
+
+ # If no candidates, randomly choose one
+ if len(candidates) == 0:
+ pred_index = random.choice(all_choices)
+ elif len(candidates) > 1:
+ # If more than one candidate, choose the last one found
+ start_indexes = [response.rfind(f" {can} ") for can in candidates]
+ pred_index = candidates[np.argmax(start_indexes)]
+ else:
+ # If only one candidate, use it
+ pred_index = candidates[0]
+
+ return pred_index
+
+
+def get_gt(doc):
+ if doc["answer_gt"] == doc["choice_a"]:
+ return "A"
+ elif doc["answer_gt"] == doc["choice_b"]:
+ return "B"
+ elif doc["answer_gt"] == doc.get("choice_c", None):
+ return "C"
+ elif doc["answer_gt"] == doc.get("choice_d", None):
+ return "D"
diff --git a/lmms_eval/tasks/alpaca_audio/alpaca_audio.yaml b/lmms_eval/tasks/alpaca_audio/alpaca_audio.yaml
new file mode 100644
index 00000000..737ba1e2
--- /dev/null
+++ b/lmms_eval/tasks/alpaca_audio/alpaca_audio.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/alpaca_audio
+dataset_kwargs:
+ token: True
+
+task: "alpaca_audio"
+test_split: test
+doc_to_target: "answer"
+doc_to_visual: !function utils.doc_to_audio
+doc_to_text: !function utils.doc_to_text
+
+generation_kwargs:
+ max_new_tokens: 1024
+ temperature: 0.2
+ top_p: 1.0
+ num_beams: 1
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nPlease give a detail answer to the question in the audio."
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.alpaca_audio_aggregate_results
+ higher_is_better: true
+
+process_results: !function utils.alpaca_audio_process_results
+
+metadata:
+ gpt_eval_model_name: gpt-4o
+ version: 0.0
diff --git a/lmms_eval/tasks/alpaca_audio/utils.py b/lmms_eval/tasks/alpaca_audio/utils.py
new file mode 100644
index 00000000..1ba1719a
--- /dev/null
+++ b/lmms_eval/tasks/alpaca_audio/utils.py
@@ -0,0 +1,138 @@
+import datetime
+import json
+import os
+import random
+import re
+import sys
+import time
+from pathlib import Path
+
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+import lmms_eval.tasks._task_utils.file_utils as file_utils
+from lmms_eval.filters.extraction import ExtendedRegexFilter
+
+
+def doc_to_audio(doc):
+ return [doc["context"]]
+
+
+def doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}{post_prompt}"
+
+
+with open(Path(__file__).parent / "alpaca_audio.yaml", "r") as f:
+ raw_data = f.readlines()
+ safe_data = []
+ for i, line in enumerate(raw_data):
+ # remove function definition since yaml load cannot handle it
+ if "!function" not in line:
+ safe_data.append(line)
+
+ config = yaml.safe_load("".join(safe_data))
+
+# specify api type and key in .env
+GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
+API_TYPE = os.getenv("API_TYPE", "azure")
+
+if API_TYPE == "openai":
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "Authorization": f"Bearer {API_KEY}",
+ "Content-Type": "application/json",
+ }
+
+elif API_TYPE == "azure":
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "api-key": API_KEY,
+ "Content-Type": "application/json",
+ }
+
+eval_prompt = """
+ [Question]
+ {question}
+
+ [Reference Answer]
+ {ground_truth}
+
+ [Model Answer]
+ {model_response}
+
+ [Task]
+ Rate the model's answer based on its alignment with the reference answer, focusing on accuracy and relevance to the reference provided. Please be critical on the details.
+ Criteria: Assess if the model's response mirrors the reference in terms of content, accuracy, and relevance.
+ Score0: The answer is completely misaligned, providing incorrect or irrelevant information compared to the reference.
+ Score1: The answer shows minimal alignment, often misunderstanding or providing irrelevant details unrelated to the reference.
+ Score2: The answer recognizes the topic but diverges significantly from the reference in accuracy or relevance.
+ Score3: The answer aligns with the reference generally but lacks detail or precise accuracy in some aspects.
+ Score4: The answer is mostly accurate and relevant, closely following the reference but could be clearer or more detailed.
+ Score5: The answer is highly accurate, detailed, and matches the reference answer perfectly, capturing its essence and detail.
+
+ Your response should be formatted as follows:
+ Explanation: (Provide a concise explanation of your rating, comparing the reference answer with the model's response. "The reference answer is [XXX], while the model's answer is [YYY]. I think ...")
+ Rating: (int)"""
+
+retries = 3
+NUM_SECONDS_TO_SLEEP = 5
+
+
+def get_eval(max_tokens: int, content: str, retries: int = retries):
+ global headers
+
+ messages = [
+ {"role": "user", "content": content},
+ ]
+
+ payload = {"model": GPT_EVAL_MODEL_NAME, "messages": messages, "temperature": 0.7, "max_tokens": max_tokens, "top_p": 0.95, "frequency_penalty": 0, "presence_penalty": 0, "stop": None}
+
+ for attempt in range(retries):
+ try:
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+ response.raise_for_status()
+ response_data = response.json()
+
+ content = response_data["choices"][0]["message"]["content"].strip()
+ if content != "":
+ return content, response_data["model"]
+ break # If successful, break out of the loop
+
+ except Exception as e:
+ eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
+            if attempt < retries - 1:  # If we have retries left, sleep and then continue to the next attempt
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+ else: # If this was the last attempt, log and return empty
+ eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
+ return "", ""
+ return "", ""
+
+
+def alpaca_audio_process_results(doc, result):
+ pred = result[0]
+ ground_truth_str = doc["answer"]
+ content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str, question=doc["speech_instruction"])
+ eval_answer, model_name = get_eval(max_tokens=1024, content=content)
+ return {
+ "gpt_eval": {"eval_answer": eval_answer, "model_name": model_name},
+ }
+
+
+def alpaca_audio_aggregate_results(results):
+ score = 0
+ for result in results:
+ try:
+ eval_answer = result["eval_answer"]
+ eval_score = re.search(r"([0-5])", eval_answer).group(1)
+ eval_score = float(eval_score)
+ except Exception as e:
+ eval_logger.error(f"Error parsing eval_score: {e}")
+ eval_score = 0.0
+ score += eval_score
+
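+    # Average of the 0-5 judge ratings, rescaled to a 0-100 score.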
+ return score / len(results) * 20
diff --git a/lmms_eval/tasks/clotho_aqa/_default_template_yaml b/lmms_eval/tasks/clotho_aqa/_default_template_yaml
new file mode 100644
index 00000000..3d167621
--- /dev/null
+++ b/lmms_eval/tasks/clotho_aqa/_default_template_yaml
@@ -0,0 +1,10 @@
+dataset_path: lmms-lab/ClothoAQA
+dataset_kwargs:
+ token: True
+doc_to_target: "answer"
+doc_to_visual: !function utils.clotho_aqa_doc_to_audio
+doc_to_text: !function utils.clotho_aqa_doc_to_text
+
+metadata:
+ gpt_eval_model_name: gpt-4o
+ version: 0.0
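+
+# Shared settings for the ClothoAQA tasks in this folder; the per-split configs pull these in
+# via "include: _default_template_yaml".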
diff --git a/lmms_eval/tasks/clotho_aqa/clotho_aqa.yaml b/lmms_eval/tasks/clotho_aqa/clotho_aqa.yaml
new file mode 100644
index 00000000..ed6e3799
--- /dev/null
+++ b/lmms_eval/tasks/clotho_aqa/clotho_aqa.yaml
@@ -0,0 +1,4 @@
+group: clotho_aqa
+task:
+ - clotho_aqa_val
+ - clotho_aqa_test
diff --git a/lmms_eval/tasks/clotho_aqa/clotho_aqa_test.yaml b/lmms_eval/tasks/clotho_aqa/clotho_aqa_test.yaml
new file mode 100644
index 00000000..47dfcfda
--- /dev/null
+++ b/lmms_eval/tasks/clotho_aqa/clotho_aqa_test.yaml
@@ -0,0 +1,19 @@
+task: "clotho_aqa_test"
+dataset_name: "clotho_aqa"
+test_split: clotho_aqa_test_filtered
+generation_kwargs:
+ max_new_tokens: 8
+ temperature: 0
+ do_sample: False
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nAnswer the question using a single word only. "
+metric_list:
+ - metric: exact_match
+ aggregation: mean
+ higher_is_better: true
+ ignore_case: true
+ ignore_punctuation: true
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/clotho_aqa/clotho_aqa_val.yaml b/lmms_eval/tasks/clotho_aqa/clotho_aqa_val.yaml
new file mode 100644
index 00000000..008472c6
--- /dev/null
+++ b/lmms_eval/tasks/clotho_aqa/clotho_aqa_val.yaml
@@ -0,0 +1,19 @@
+task: "clotho_aqa_val"
+dataset_name: "clotho_aqa"
+test_split: clotho_aqa_val_filtered
+generation_kwargs:
+ max_new_tokens: 8
+ temperature: 0
+ do_sample: False
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nAnswer the question using a single word only. "
+metric_list:
+ - metric: exact_match
+ aggregation: mean
+ higher_is_better: true
+ ignore_case: true
+ ignore_punctuation: true
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/clotho_aqa/clotho_asqa_test_v2.yaml b/lmms_eval/tasks/clotho_aqa/clotho_asqa_test_v2.yaml
new file mode 100644
index 00000000..d8de65ec
--- /dev/null
+++ b/lmms_eval/tasks/clotho_aqa/clotho_asqa_test_v2.yaml
@@ -0,0 +1,19 @@
+task: "clotho_asqa_test_v2"
+dataset_name: "clotho_asqa_test_v2"
+test_split: test
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ do_sample: False
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.clotho_aqa_v2_aggregate_results
+ higher_is_better: true
+
+process_results: !function utils.clotho_aqa_v2_process_results
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/clotho_aqa/utils.py b/lmms_eval/tasks/clotho_aqa/utils.py
new file mode 100644
index 00000000..b7148003
--- /dev/null
+++ b/lmms_eval/tasks/clotho_aqa/utils.py
@@ -0,0 +1,142 @@
+import datetime
+import json
+import os
+import re
+import sys
+import time
+from pathlib import Path
+
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+import lmms_eval.tasks._task_utils.file_utils as file_utils
+from lmms_eval.filters.extraction import ExtendedRegexFilter
+
+
+def clotho_aqa_doc_to_audio(doc):
+ return [doc["audio"]]
+
+
+def clotho_aqa_doc_to_text(doc, lmms_eval_specific_kwargs):
+ question = doc["question"]
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}{question}{post_prompt}"
+
+
+# functions for the clotho_asqa_v2 task, need to be tested later
+
+with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
+ raw_data = f.readlines()
+ safe_data = []
+ for i, line in enumerate(raw_data):
+ # remove function definition since yaml load cannot handle it
+ if "!function" not in line:
+ safe_data.append(line)
+
+ config = yaml.safe_load("".join(safe_data))
+
+
+NUM_SECONDS_TO_SLEEP = 2
+GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
+API_TYPE = os.getenv("API_TYPE", "azure")
+
+if API_TYPE == "openai":
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "Authorization": f"Bearer {API_KEY}",
+ "Content-Type": "application/json",
+ }
+
+elif API_TYPE == "azure":
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "api-key": API_KEY,
+ "Content-Type": "application/json",
+ }
+
+eval_prompt = """
+ [Question]
+ {question}
+
+ [Reference Answer]
+ {ground_truth}
+
+ [Model Answer]
+ {model_response}
+
+ [Task]
+ Rate the model's answer based on its alignment with the reference answer, focusing on accuracy and relevance to the reference provided. Please be critical on the details.
+ Criteria: Assess if the model's response mirrors the reference in terms of content, accuracy, and relevance.
+ Score0: The answer is completely misaligned, providing incorrect or irrelevant information compared to the reference.
+ Score1: The answer shows minimal alignment, often misunderstanding or providing irrelevant details unrelated to the reference.
+ Score2: The answer recognizes the topic but diverges significantly from the reference in accuracy or relevance.
+ Score3: The answer aligns with the reference generally but lacks detail or precise accuracy in some aspects.
+ Score4: The answer is mostly accurate and relevant, closely following the reference but could be clearer or more detailed.
+ Score5: The answer is highly accurate, detailed, and matches the reference answer perfectly, capturing its essence and detail.
+
+ Your response should be formatted as follows:
+ Explanation: (Provide a concise explanation of your rating, comparing the reference answer with the model's response. "The reference answer is [XXX], while the model's answer is [YYY]. I think ...")
+ Rating: (int)"""
+
+
+retries = 3
+NUM_SECONDS_TO_SLEEP = 5
+
+
+def get_eval(max_tokens: int, content: str, retries: int = retries):
+ global headers
+
+ messages = [
+ {"role": "user", "content": content},
+ ]
+
+ payload = {"model": GPT_EVAL_MODEL_NAME, "messages": messages, "temperature": 0.7, "max_tokens": max_tokens, "top_p": 0.95, "frequency_penalty": 0, "presence_penalty": 0, "stop": None}
+
+ for attempt in range(retries):
+ try:
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+ response.raise_for_status()
+ response_data = response.json()
+
+ content = response_data["choices"][0]["message"]["content"].strip()
+ if content != "":
+ return content, response_data["model"]
+ break # If successful, break out of the loop
+
+ except Exception as e:
+ eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
+            if attempt < retries - 1:  # If we have retries left, sleep and then continue to the next attempt
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+ else: # If this was the last attempt, log and return empty
+ eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
+ return "", ""
+ return "", ""
+
+
+def clotho_aqa_v2_process_results(doc, result):
+ pred = result[0]
+ ground_truth_str = doc["answer"]
+ content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str, question=doc["question"])
+ eval_answer, model_name = get_eval(max_tokens=1024, content=content)
+ return {
+ "gpt_eval": {"eval_answer": eval_answer, "model_name": model_name},
+ }
+
+
+def clotho_aqa_v2_aggregate_results(results):
+ score = 0
+ for result in results:
+        eval_answer = result["eval_answer"]
+        try:
+            eval_score = re.search(r"([0-5])", eval_answer).group(1)
+            eval_score = float(eval_score)
+        except Exception as e:
+            eval_logger.error(f"Error parsing eval_score: {e}")
+            eval_score = 0.0
+ score += eval_score
+
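+    # Average of the 0-5 judge ratings, rescaled to a 0-100 score.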
+ return score / len(results) * 20
diff --git a/lmms_eval/tasks/common_voice_15/common_voice_15.yaml b/lmms_eval/tasks/common_voice_15/common_voice_15.yaml
new file mode 100644
index 00000000..4ae38dfd
--- /dev/null
+++ b/lmms_eval/tasks/common_voice_15/common_voice_15.yaml
@@ -0,0 +1,5 @@
+group: common_voice_15
+task:
+- common_voice_15_zh-CN
+- common_voice_15_en
+- common_voice_15_fr
\ No newline at end of file
diff --git a/lmms_eval/tasks/common_voice_15/common_voice_15_en.yaml b/lmms_eval/tasks/common_voice_15/common_voice_15_en.yaml
new file mode 100644
index 00000000..b8e2fb36
--- /dev/null
+++ b/lmms_eval/tasks/common_voice_15/common_voice_15_en.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/common_voice_15
+dataset_kwargs:
+ token: True
+task : "common_voice_15_en"
+test_split: test
+dataset_name: en
+output_type: generate_until
+doc_to_visual: !function utils.common_voice_15_doc_to_audio
+doc_to_text: !function utils.common_voice_15_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.common_voice_15_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.common_voice_15_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/common_voice_15/common_voice_15_fr.yaml b/lmms_eval/tasks/common_voice_15/common_voice_15_fr.yaml
new file mode 100644
index 00000000..1254c2bc
--- /dev/null
+++ b/lmms_eval/tasks/common_voice_15/common_voice_15_fr.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/common_voice_15
+dataset_kwargs:
+ token: True
+task : "common_voice_15_fr"
+test_split: test
+dataset_name: fr
+output_type: generate_until
+doc_to_visual: !function utils.common_voice_15_doc_to_audio
+doc_to_text: !function utils.common_voice_15_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.common_voice_15_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.common_voice_15_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|fr|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/common_voice_15/common_voice_15_zh-CN.yaml b/lmms_eval/tasks/common_voice_15/common_voice_15_zh-CN.yaml
new file mode 100644
index 00000000..92dba1d5
--- /dev/null
+++ b/lmms_eval/tasks/common_voice_15/common_voice_15_zh-CN.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/common_voice_15
+dataset_kwargs:
+ token: True
+task : "common_voice_15_zh-CN"
+test_split: test
+dataset_name: zh-CN
+output_type: generate_until
+doc_to_visual: !function utils.common_voice_15_doc_to_audio
+doc_to_text: !function utils.common_voice_15_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.common_voice_15_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.common_voice_15_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|zh|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/common_voice_15/utils.py b/lmms_eval/tasks/common_voice_15/utils.py
new file mode 100644
index 00000000..312bfa17
--- /dev/null
+++ b/lmms_eval/tasks/common_voice_15/utils.py
@@ -0,0 +1,182 @@
+import os
+import re
+import unicodedata
+
+import editdistance as ed
+import zhconv
+
+from lmms_eval.tasks.librispeech.cn_tn import TextNorm
+from lmms_eval.tasks.librispeech.whisper_normalizer.basic import BasicTextNormalizer
+from lmms_eval.tasks.librispeech.whisper_normalizer.english import EnglishTextNormalizer
+
+# Note: decoding the audio column requires the `librosa` and `soundfile` packages to be installed.
+english_normalizer = EnglishTextNormalizer()
+chinese_normalizer = TextNorm(
+ to_banjiao=False,
+ to_upper=False,
+ to_lower=False,
+ remove_fillers=False,
+ remove_erhua=False,
+ check_chars=False,
+ remove_space=False,
+ cc_mode="",
+)
+basic_normalizer = BasicTextNormalizer()
+
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
+
+def common_voice_15_doc_to_audio(doc):
+ return [doc["audio"]]
+
+
+def common_voice_15_doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"
+
+
+def common_voice_15_process_result(doc, result):
+ pred = result[0] if len(result) > 0 else ""
+
+ gt = doc["gt"]
+ source = doc["source"]
+ task = doc["task"]
+
+ data_dict = {"gt": gt, "pred": pred, "source": source, "task": task}
+
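+    # Only collect the reference/prediction pairs here; the corpus-level WER is computed by
+    # common_voice_15_wer at aggregation time.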
+ return {"wer": data_dict}
+
+
+PUNCS = "!,.?;:"
+
+
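+# remove_sp strips special-token markers of the form <|...|> (e.g. the "<|en|>" / "<|zh|>" tags
+# used in the qwen2_audio prompts), collapses whitespace, removes spaces before punctuation, and
+# drops all remaining spaces for Chinese.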
+def remove_sp(text, language):
+ gt = re.sub(r"<\|.*?\|>", " ", text)
+ gt = re.sub(rf"\s+", r" ", gt) # Replace consecutive spaces in the text with a single space.
+ gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
+ gt = gt.lstrip(" ")
+ if language == "zh":
+ gt = re.sub(rf"\s+", r"", gt)
+ return gt
+
+
+class EvaluationTokenizer(object):
+ """A generic evaluation-time tokenizer, which leverages built-in tokenizers
+ in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
+ lowercasing, punctuation removal and character tokenization, which are
+ applied after sacreBLEU tokenization.
+
+ Args:
+ tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
+ lowercase (bool): lowercase the text.
+ punctuation_removal (bool): remove punctuation (based on unicode
+ category) from text.
+ character_tokenization (bool): tokenize the text to characters.
+ """
+
+ SPACE = chr(32)
+ SPACE_ESCAPE = chr(9601)
+ # ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
+
+ def __init__(
+ self,
+ tokenizer_type: str = "13a",
+ lowercase: bool = False,
+ punctuation_removal: bool = False,
+ character_tokenization: bool = False,
+ ):
+ from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
+ from sacrebleu.tokenizers.tokenizer_char import TokenizerChar
+ from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International
+ from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab
+ from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer
+ from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh
+
+ TOKENIZERS = {
+ "none": NoneTokenizer,
+ "13a": Tokenizer13a,
+ "intl": TokenizerV14International,
+ "zh": TokenizerZh,
+ "ja-mecab": TokenizerJaMecab,
+ "char": TokenizerChar,
+ }
+
+ assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
+ self.lowercase = lowercase
+ self.punctuation_removal = punctuation_removal
+ self.character_tokenization = character_tokenization
+ self.tokenizer = TOKENIZERS[tokenizer_type]
+ # self.tokenizer = tokenizer_none
+
+ @classmethod
+ def remove_punctuation(cls, sent: str):
+ """Remove punctuation based on Unicode category."""
+ return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t))
+
+ def tokenize(self, sent: str):
+ tokenized = self.tokenizer()(sent)
+
+ if self.punctuation_removal:
+ tokenized = self.remove_punctuation(tokenized)
+
+ if self.character_tokenization:
+ tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))
+
+ if self.lowercase:
+ tokenized = tokenized.lower()
+
+ return tokenized
+
+
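+# compute_wer: corpus-level WER = total edit distance between the tokenized hypotheses and
+# references, divided by the total number of reference tokens. Text is normalized per language
+# first (zhconv for yue, the English/Chinese normalizers for en/zh; note the basic normalizer is
+# additionally applied to every language except zh, because the else pairs with the zh check),
+# and zh/yue are scored at the character level.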
+def compute_wer(refs, hyps, language):
+ distance = 0
+ ref_length = 0
+ tokenizer = EvaluationTokenizer(
+ tokenizer_type="none",
+ lowercase=True,
+ punctuation_removal=True,
+ character_tokenization=False,
+ )
+ for i in range(len(refs)):
+ ref = refs[i]
+ pred = hyps[i]
+ if language in ["yue"]:
+ ref = zhconv.convert(ref, "zh-cn")
+ pred = zhconv.convert(pred, "zh-cn")
+ if language in ["en"]:
+ ref = english_normalizer(ref)
+ pred = english_normalizer(pred)
+ if language in ["zh"]:
+ ref = chinese_normalizer(ref)
+ pred = chinese_normalizer(pred)
+ else:
+ ref = basic_normalizer(ref)
+ pred = basic_normalizer(pred)
+ ref_items = tokenizer.tokenize(ref).split()
+ pred_items = tokenizer.tokenize(pred).split()
+ if language in ["zh", "yue"]:
+ ref_items = [x for x in "".join(ref_items)]
+ pred_items = [x for x in "".join(pred_items)]
+ if i == 0:
+ print(f"ref: {ref}")
+ print(f"pred: {pred}")
+ print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
+            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{pred_items[0]}")
+ distance += ed.eval(ref_items, pred_items)
+ ref_length += len(ref_items)
+ return distance / ref_length
+
+
+def common_voice_15_wer(results, args):
+ refs, hyps = [], []
+ for result in results:
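+        # The dataset's "task" field appears to encode the language as a suffix (e.g. "asr_en"),
+        # so the first four characters are stripped to recover the language code; this is an
+        # assumption inferred from the slice below rather than documented behaviour.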
+ lan = result["task"][4:]
+ gt = result["gt"]
+ response = result["pred"]
+ gt = remove_sp(gt, lan)
+ response = remove_sp(response, lan)
+ refs.append(gt)
+ hyps.append(response)
+ wer = compute_wer(refs, hyps, lan)
+ return wer * 100
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech.yaml
new file mode 100644
index 00000000..55d33b41
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech.yaml
@@ -0,0 +1,4 @@
+group: gigaspeech
+task:
+ - gigaspeech_dev
+ - gigaspeech_test
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_dev.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_dev.yaml
new file mode 100644
index 00000000..d4ed3346
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_dev.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: dev
+dataset_kwargs:
+ token: True
+task : "gigaspeech_dev"
+test_split: validation
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_l_dev.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_l_dev.yaml
new file mode 100644
index 00000000..97fa7f4e
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_l_dev.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: l
+dataset_kwargs:
+ token: True
+task : "gigaspeech_l_dev"
+test_split: validation
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_l_test.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_l_test.yaml
new file mode 100644
index 00000000..75bb14f0
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_l_test.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: l
+dataset_kwargs:
+ token: True
+task : "gigaspeech_l_test"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_m_dev.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_m_dev.yaml
new file mode 100644
index 00000000..29f0bbfd
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_m_dev.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: m
+dataset_kwargs:
+ token: True
+task : "gigaspeech_m_dev"
+test_split: validation
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_m_test.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_m_test.yaml
new file mode 100644
index 00000000..9e9dfbcb
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_m_test.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: m
+dataset_kwargs:
+ token: True
+task : "gigaspeech_m_test"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_s_dev.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_s_dev.yaml
new file mode 100644
index 00000000..cbdf8fe9
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_s_dev.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: s
+dataset_kwargs:
+ token: True
+task : "gigaspeech_s_dev"
+test_split: validation
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_s_test.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_s_test.yaml
new file mode 100644
index 00000000..d878a0cc
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_s_test.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: s
+dataset_kwargs:
+ token: True
+task : "gigaspeech_s_test"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_test.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_test.yaml
new file mode 100644
index 00000000..d6f98db1
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_test.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: test
+dataset_kwargs:
+ token: True
+task : "gigaspeech_test"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_xl_dev.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_xl_dev.yaml
new file mode 100644
index 00000000..86d39b69
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_xl_dev.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: xl
+dataset_kwargs:
+ token: True
+task : "gigaspeech_xl_dev"
+test_split: validation
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_xl_test.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_xl_test.yaml
new file mode 100644
index 00000000..67e1f463
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_xl_test.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: xl
+dataset_kwargs:
+ token: True
+task : "gigaspeech_xl_test"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_xs_dev.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_xs_dev.yaml
new file mode 100644
index 00000000..558a154e
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_xs_dev.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: xs
+dataset_kwargs:
+ token: True
+task : "gigaspeech_xs_dev"
+test_split: validation
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/gigaspeech_xs_test.yaml b/lmms_eval/tasks/gigaspeech/gigaspeech_xs_test.yaml
new file mode 100644
index 00000000..744a0260
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/gigaspeech_xs_test.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/gigaspeech
+dataset_name: xs
+dataset_kwargs:
+ token: True
+task : "gigaspeech_xs_test"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.gigaspeech_doc_to_audio
+doc_to_text: !function utils.gigaspeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.gigaspeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.gigaspeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/utils.py b/lmms_eval/tasks/gigaspeech/utils.py
new file mode 100755
index 00000000..c45a9d61
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/utils.py
@@ -0,0 +1,162 @@
+import os
+import re
+import unicodedata
+
+import editdistance as ed # TODO: new package
+
+from lmms_eval.tasks.gigaspeech.whisper_normalizer.basic import BasicTextNormalizer
+from lmms_eval.tasks.gigaspeech.whisper_normalizer.english import EnglishTextNormalizer
+
+# Note: decoding the audio column requires the `librosa` and `soundfile` packages to be installed.
+english_normalizer = EnglishTextNormalizer()
+
+basic_normalizer = BasicTextNormalizer()
+
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
+
+def gigaspeech_doc_to_audio(doc):
+ return [doc["audio"]]
+
+
+def gigaspeech_doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"
+
+
+def gigaspeech_process_result(doc, result):
+ pred = result[0] if len(result) > 0 else ""
+ gt = doc["gt"]
+ data_dict = {"gt": gt, "pred": pred}
+
+ return {"wer": data_dict}
+
+
+def gigaspeech_xl_process_result(doc, result):
+ pred = result[0] if len(result) > 0 else ""
+ gt = doc["text"]
+ data_dict = {"gt": gt, "pred": pred}
+
+ return {"wer": data_dict}
+
+
+PUNCS = "!,.?;:"
+
+
+def remove_sp(text):
+ gt = re.sub(r"<\|.*?\|>", " ", text)
+ gt = re.sub(rf"\s+", r" ", gt) # Replace consecutive spaces in the text with a single space.
+ gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
+ gt = gt.lstrip(" ")
+ return gt
+
+
+class EvaluationTokenizer(object):
+ """A generic evaluation-time tokenizer, which leverages built-in tokenizers
+ in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
+ lowercasing, punctuation removal and character tokenization, which are
+ applied after sacreBLEU tokenization.
+
+ Args:
+ tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
+ lowercase (bool): lowercase the text.
+ punctuation_removal (bool): remove punctuation (based on unicode
+ category) from text.
+ character_tokenization (bool): tokenize the text to characters.
+ """
+
+ SPACE = chr(32)
+ SPACE_ESCAPE = chr(9601)
+ # ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
+
+ def __init__(
+ self,
+ tokenizer_type: str = "13a",
+ lowercase: bool = False,
+ punctuation_removal: bool = False,
+ character_tokenization: bool = False,
+ ):
+ # from sacrebleu.tokenizers import TOKENIZERS
+ # from sacrebleu.tokenizers import tokenizer_none
+ from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
+ from sacrebleu.tokenizers.tokenizer_char import TokenizerChar
+ from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International
+ from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab
+ from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer
+ from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh
+
+ TOKENIZERS = {
+ "none": NoneTokenizer,
+ "13a": Tokenizer13a,
+ "intl": TokenizerV14International,
+ "zh": TokenizerZh,
+ "ja-mecab": TokenizerJaMecab,
+ "char": TokenizerChar,
+ }
+
+ assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
+ self.lowercase = lowercase
+ self.punctuation_removal = punctuation_removal
+ self.character_tokenization = character_tokenization
+ self.tokenizer = TOKENIZERS[tokenizer_type]
+ # self.tokenizer = tokenizer_none
+
+ @classmethod
+ def remove_punctuation(cls, sent: str):
+ """Remove punctuation based on Unicode category."""
+ return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t))
+
+ def tokenize(self, sent: str):
+ tokenized = self.tokenizer()(sent)
+
+ if self.punctuation_removal:
+ tokenized = self.remove_punctuation(tokenized)
+
+ if self.character_tokenization:
+ tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))
+
+ if self.lowercase:
+ tokenized = tokenized.lower()
+
+ return tokenized
+
+
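+# compute_wer: corpus-level WER = total edit distance between the English-normalized, tokenized
+# hypotheses and references, divided by the total number of reference tokens.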
+def compute_wer(refs, hyps):
+ distance = 0
+ ref_length = 0
+ tokenizer = EvaluationTokenizer(
+ tokenizer_type="none",
+ lowercase=True,
+ punctuation_removal=True,
+ character_tokenization=False,
+ )
+ for i in range(len(refs)):
+ ref = refs[i]
+ pred = hyps[i]
+ ref = english_normalizer(ref)
+ pred = english_normalizer(pred)
+ ref_items = tokenizer.tokenize(ref).split()
+ pred_items = tokenizer.tokenize(pred).split()
+ if i == 0:
+ print(f"ref: {ref}")
+ print(f"pred: {pred}")
+ print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
+            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{pred_items[0]}")
+ distance += ed.eval(ref_items, pred_items)
+ ref_length += len(ref_items)
+ return distance / ref_length
+
+
+def gigaspeech_wer(results, args):
+ refs, hyps = [], []
+ for result in results:
+ gt = result["gt"]
+ response = result["pred"]
+ gt = remove_sp(gt)
+ response = remove_sp(response)
+ refs.append(gt)
+ hyps.append(response)
+ wer = compute_wer(refs, hyps)
+ # print(f"source: {source} cnt: {len(refs)} wer: {wer:.4f}")
+ return wer * 100
diff --git a/lmms_eval/tasks/gigaspeech/whisper_normalizer/basic.py b/lmms_eval/tasks/gigaspeech/whisper_normalizer/basic.py
new file mode 100644
index 00000000..00a54dcc
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/whisper_normalizer/basic.py
@@ -0,0 +1,58 @@
+import re
+import unicodedata
+
+import regex
+
+# non-ASCII letters that are not separated by "NFKD" normalization
+ADDITIONAL_DIACRITICS = {
+    "œ": "oe",
+    "Œ": "OE",
+    "ø": "o",
+    "Ø": "O",
+    "æ": "ae",
+    "Æ": "AE",
+    "ß": "ss",
+    "ẞ": "SS",
+    "đ": "d",
+    "Đ": "D",
+    "ð": "d",
+    "Ð": "D",
+    "þ": "th",
+    "Þ": "th",
+    "ł": "l",
+    "Ł": "L",
+}
+
+
+def remove_symbols_and_diacritics(s: str, keep=""):
+ """
+ Replace any other markers, symbols, and punctuations with a space,
+ and drop any diacritics (category 'Mn' and some manual mappings)
+ """
+ return "".join(c if c in keep else ADDITIONAL_DIACRITICS[c] if c in ADDITIONAL_DIACRITICS else "" if unicodedata.category(c) == "Mn" else " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKD", s))
+
+
+def remove_symbols(s: str):
+ """
+ Replace any other markers, symbols, punctuations with a space, keeping diacritics
+ """
+ return "".join(" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s))
+
+
+class BasicTextNormalizer:
+ def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
+ self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+ self.split_letters = split_letters
+
+ def __call__(self, s: str):
+ s = s.lower()
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
+ s = self.clean(s).lower()
+
+ if self.split_letters:
+ s = " ".join(regex.findall(r"\X", s, regex.U))
+
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
+
+ return s
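+
+
+# Illustrative example with default settings: BasicTextNormalizer()("Hello,  World [noise]")
+# returns "hello world " (bracketed spans are dropped, punctuation and symbols become spaces,
+# and runs of whitespace collapse to a single space; nothing is stripped at the ends).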
diff --git a/lmms_eval/tasks/gigaspeech/whisper_normalizer/english.json b/lmms_eval/tasks/gigaspeech/whisper_normalizer/english.json
new file mode 100644
index 00000000..566e4812
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/whisper_normalizer/english.json
@@ -0,0 +1,1741 @@
+{
+ "accessorise": "accessorize",
+ "accessorised": "accessorized",
+ "accessorises": "accessorizes",
+ "accessorising": "accessorizing",
+ "acclimatisation": "acclimatization",
+ "acclimatise": "acclimatize",
+ "acclimatised": "acclimatized",
+ "acclimatises": "acclimatizes",
+ "acclimatising": "acclimatizing",
+ "accoutrements": "accouterments",
+ "aeon": "eon",
+ "aeons": "eons",
+ "aerogramme": "aerogram",
+ "aerogrammes": "aerograms",
+ "aeroplane": "airplane",
+ "aeroplanes": "airplanes",
+ "aesthete": "esthete",
+ "aesthetes": "esthetes",
+ "aesthetic": "esthetic",
+ "aesthetically": "esthetically",
+ "aesthetics": "esthetics",
+ "aetiology": "etiology",
+ "ageing": "aging",
+ "aggrandisement": "aggrandizement",
+ "agonise": "agonize",
+ "agonised": "agonized",
+ "agonises": "agonizes",
+ "agonising": "agonizing",
+ "agonisingly": "agonizingly",
+ "almanack": "almanac",
+ "almanacks": "almanacs",
+ "aluminium": "aluminum",
+ "amortisable": "amortizable",
+ "amortisation": "amortization",
+ "amortisations": "amortizations",
+ "amortise": "amortize",
+ "amortised": "amortized",
+ "amortises": "amortizes",
+ "amortising": "amortizing",
+ "amphitheatre": "amphitheater",
+ "amphitheatres": "amphitheaters",
+ "anaemia": "anemia",
+ "anaemic": "anemic",
+ "anaesthesia": "anesthesia",
+ "anaesthetic": "anesthetic",
+ "anaesthetics": "anesthetics",
+ "anaesthetise": "anesthetize",
+ "anaesthetised": "anesthetized",
+ "anaesthetises": "anesthetizes",
+ "anaesthetising": "anesthetizing",
+ "anaesthetist": "anesthetist",
+ "anaesthetists": "anesthetists",
+ "anaesthetize": "anesthetize",
+ "anaesthetized": "anesthetized",
+ "anaesthetizes": "anesthetizes",
+ "anaesthetizing": "anesthetizing",
+ "analogue": "analog",
+ "analogues": "analogs",
+ "analyse": "analyze",
+ "analysed": "analyzed",
+ "analyses": "analyzes",
+ "analysing": "analyzing",
+ "anglicise": "anglicize",
+ "anglicised": "anglicized",
+ "anglicises": "anglicizes",
+ "anglicising": "anglicizing",
+ "annualised": "annualized",
+ "antagonise": "antagonize",
+ "antagonised": "antagonized",
+ "antagonises": "antagonizes",
+ "antagonising": "antagonizing",
+ "apologise": "apologize",
+ "apologised": "apologized",
+ "apologises": "apologizes",
+ "apologising": "apologizing",
+ "appal": "appall",
+ "appals": "appalls",
+ "appetiser": "appetizer",
+ "appetisers": "appetizers",
+ "appetising": "appetizing",
+ "appetisingly": "appetizingly",
+ "arbour": "arbor",
+ "arbours": "arbors",
+ "archeological": "archaeological",
+ "archaeologically": "archeologically",
+ "archaeologist": "archeologist",
+ "archaeologists": "archeologists",
+ "archaeology": "archeology",
+ "ardour": "ardor",
+ "armour": "armor",
+ "armoured": "armored",
+ "armourer": "armorer",
+ "armourers": "armorers",
+ "armouries": "armories",
+ "armoury": "armory",
+ "artefact": "artifact",
+ "artefacts": "artifacts",
+ "authorise": "authorize",
+ "authorised": "authorized",
+ "authorises": "authorizes",
+ "authorising": "authorizing",
+ "axe": "ax",
+ "backpedalled": "backpedaled",
+ "backpedalling": "backpedaling",
+ "bannister": "banister",
+ "bannisters": "banisters",
+ "baptise": "baptize",
+ "baptised": "baptized",
+ "baptises": "baptizes",
+ "baptising": "baptizing",
+ "bastardise": "bastardize",
+ "bastardised": "bastardized",
+ "bastardises": "bastardizes",
+ "bastardising": "bastardizing",
+ "battleax": "battleaxe",
+ "baulk": "balk",
+ "baulked": "balked",
+ "baulking": "balking",
+ "baulks": "balks",
+ "bedevilled": "bedeviled",
+ "bedevilling": "bedeviling",
+ "behaviour": "behavior",
+ "behavioural": "behavioral",
+ "behaviourism": "behaviorism",
+ "behaviourist": "behaviorist",
+ "behaviourists": "behaviorists",
+ "behaviours": "behaviors",
+ "behove": "behoove",
+ "behoved": "behooved",
+ "behoves": "behooves",
+ "bejewelled": "bejeweled",
+ "belabour": "belabor",
+ "belaboured": "belabored",
+ "belabouring": "belaboring",
+ "belabours": "belabors",
+ "bevelled": "beveled",
+ "bevvies": "bevies",
+ "bevvy": "bevy",
+ "biassed": "biased",
+ "biassing": "biasing",
+ "bingeing": "binging",
+ "bougainvillaea": "bougainvillea",
+ "bougainvillaeas": "bougainvilleas",
+ "bowdlerise": "bowdlerize",
+ "bowdlerised": "bowdlerized",
+ "bowdlerises": "bowdlerizes",
+ "bowdlerising": "bowdlerizing",
+ "breathalyse": "breathalyze",
+ "breathalysed": "breathalyzed",
+ "breathalyser": "breathalyzer",
+ "breathalysers": "breathalyzers",
+ "breathalyses": "breathalyzes",
+ "breathalysing": "breathalyzing",
+ "brutalise": "brutalize",
+ "brutalised": "brutalized",
+ "brutalises": "brutalizes",
+ "brutalising": "brutalizing",
+ "busses": "buses",
+ "bussing": "busing",
+ "caesarean": "cesarean",
+ "caesareans": "cesareans",
+ "calibre": "caliber",
+ "calibres": "calibers",
+ "calliper": "caliper",
+ "callipers": "calipers",
+ "callisthenics": "calisthenics",
+ "canalise": "canalize",
+ "canalised": "canalized",
+ "canalises": "canalizes",
+ "canalising": "canalizing",
+ "cancelation": "cancellation",
+ "cancelations": "cancellations",
+ "cancelled": "canceled",
+ "cancelling": "canceling",
+ "candour": "candor",
+ "cannibalise": "cannibalize",
+ "cannibalised": "cannibalized",
+ "cannibalises": "cannibalizes",
+ "cannibalising": "cannibalizing",
+ "canonise": "canonize",
+ "canonised": "canonized",
+ "canonises": "canonizes",
+ "canonising": "canonizing",
+ "capitalise": "capitalize",
+ "capitalised": "capitalized",
+ "capitalises": "capitalizes",
+ "capitalising": "capitalizing",
+ "caramelise": "caramelize",
+ "caramelised": "caramelized",
+ "caramelises": "caramelizes",
+ "caramelising": "caramelizing",
+ "carbonise": "carbonize",
+ "carbonised": "carbonized",
+ "carbonises": "carbonizes",
+ "carbonising": "carbonizing",
+ "carolled": "caroled",
+ "carolling": "caroling",
+ "catalogue": "catalog",
+ "catalogued": "cataloged",
+ "catalogues": "catalogs",
+ "cataloguing": "cataloging",
+ "catalyse": "catalyze",
+ "catalysed": "catalyzed",
+ "catalyses": "catalyzes",
+ "catalysing": "catalyzing",
+ "categorise": "categorize",
+ "categorised": "categorized",
+ "categorises": "categorizes",
+ "categorising": "categorizing",
+ "cauterise": "cauterize",
+ "cauterised": "cauterized",
+ "cauterises": "cauterizes",
+ "cauterising": "cauterizing",
+ "cavilled": "caviled",
+ "cavilling": "caviling",
+ "centigramme": "centigram",
+ "centigrammes": "centigrams",
+ "centilitre": "centiliter",
+ "centilitres": "centiliters",
+ "centimetre": "centimeter",
+ "centimetres": "centimeters",
+ "centralise": "centralize",
+ "centralised": "centralized",
+ "centralises": "centralizes",
+ "centralising": "centralizing",
+ "centre": "center",
+ "centred": "centered",
+ "centrefold": "centerfold",
+ "centrefolds": "centerfolds",
+ "centrepiece": "centerpiece",
+ "centrepieces": "centerpieces",
+ "centres": "centers",
+ "channelled": "channeled",
+ "channelling": "channeling",
+ "characterise": "characterize",
+ "characterised": "characterized",
+ "characterises": "characterizes",
+ "characterising": "characterizing",
+ "cheque": "check",
+ "chequebook": "checkbook",
+ "chequebooks": "checkbooks",
+ "chequered": "checkered",
+ "cheques": "checks",
+ "chilli": "chili",
+ "chimaera": "chimera",
+ "chimaeras": "chimeras",
+ "chiselled": "chiseled",
+ "chiselling": "chiseling",
+ "circularise": "circularize",
+ "circularised": "circularized",
+ "circularises": "circularizes",
+ "circularising": "circularizing",
+ "civilise": "civilize",
+ "civilised": "civilized",
+ "civilises": "civilizes",
+ "civilising": "civilizing",
+ "clamour": "clamor",
+ "clamoured": "clamored",
+ "clamouring": "clamoring",
+ "clamours": "clamors",
+ "clangour": "clangor",
+ "clarinettist": "clarinetist",
+ "clarinettists": "clarinetists",
+ "collectivise": "collectivize",
+ "collectivised": "collectivized",
+ "collectivises": "collectivizes",
+ "collectivising": "collectivizing",
+ "colonisation": "colonization",
+ "colonise": "colonize",
+ "colonised": "colonized",
+ "coloniser": "colonizer",
+ "colonisers": "colonizers",
+ "colonises": "colonizes",
+ "colonising": "colonizing",
+ "colour": "color",
+ "colourant": "colorant",
+ "colourants": "colorants",
+ "coloured": "colored",
+ "coloureds": "coloreds",
+ "colourful": "colorful",
+ "colourfully": "colorfully",
+ "colouring": "coloring",
+ "colourize": "colorize",
+ "colourized": "colorized",
+ "colourizes": "colorizes",
+ "colourizing": "colorizing",
+ "colourless": "colorless",
+ "colours": "colors",
+ "commercialise": "commercialize",
+ "commercialised": "commercialized",
+ "commercialises": "commercializes",
+ "commercialising": "commercializing",
+ "compartmentalise": "compartmentalize",
+ "compartmentalised": "compartmentalized",
+ "compartmentalises": "compartmentalizes",
+ "compartmentalising": "compartmentalizing",
+ "computerise": "computerize",
+ "computerised": "computerized",
+ "computerises": "computerizes",
+ "computerising": "computerizing",
+ "conceptualise": "conceptualize",
+ "conceptualised": "conceptualized",
+ "conceptualises": "conceptualizes",
+ "conceptualising": "conceptualizing",
+ "connexion": "connection",
+ "connexions": "connections",
+ "contextualise": "contextualize",
+ "contextualised": "contextualized",
+ "contextualises": "contextualizes",
+ "contextualising": "contextualizing",
+ "cosier": "cozier",
+ "cosies": "cozies",
+ "cosiest": "coziest",
+ "cosily": "cozily",
+ "cosiness": "coziness",
+ "cosy": "cozy",
+ "councillor": "councilor",
+ "councillors": "councilors",
+ "counselled": "counseled",
+ "counselling": "counseling",
+ "counsellor": "counselor",
+ "counsellors": "counselors",
+ "crenelated": "crenellated",
+ "criminalise": "criminalize",
+ "criminalised": "criminalized",
+ "criminalises": "criminalizes",
+ "criminalising": "criminalizing",
+ "criticise": "criticize",
+ "criticised": "criticized",
+ "criticises": "criticizes",
+ "criticising": "criticizing",
+ "crueller": "crueler",
+ "cruellest": "cruelest",
+ "crystallisation": "crystallization",
+ "crystallise": "crystallize",
+ "crystallised": "crystallized",
+ "crystallises": "crystallizes",
+ "crystallising": "crystallizing",
+ "cudgelled": "cudgeled",
+ "cudgelling": "cudgeling",
+ "customise": "customize",
+ "customised": "customized",
+ "customises": "customizes",
+ "customising": "customizing",
+ "cypher": "cipher",
+ "cyphers": "ciphers",
+ "decentralisation": "decentralization",
+ "decentralise": "decentralize",
+ "decentralised": "decentralized",
+ "decentralises": "decentralizes",
+ "decentralising": "decentralizing",
+ "decriminalisation": "decriminalization",
+ "decriminalise": "decriminalize",
+ "decriminalised": "decriminalized",
+ "decriminalises": "decriminalizes",
+ "decriminalising": "decriminalizing",
+ "defence": "defense",
+ "defenceless": "defenseless",
+ "defences": "defenses",
+ "dehumanisation": "dehumanization",
+ "dehumanise": "dehumanize",
+ "dehumanised": "dehumanized",
+ "dehumanises": "dehumanizes",
+ "dehumanising": "dehumanizing",
+ "demeanour": "demeanor",
+ "demilitarisation": "demilitarization",
+ "demilitarise": "demilitarize",
+ "demilitarised": "demilitarized",
+ "demilitarises": "demilitarizes",
+ "demilitarising": "demilitarizing",
+ "demobilisation": "demobilization",
+ "demobilise": "demobilize",
+ "demobilised": "demobilized",
+ "demobilises": "demobilizes",
+ "demobilising": "demobilizing",
+ "democratisation": "democratization",
+ "democratise": "democratize",
+ "democratised": "democratized",
+ "democratises": "democratizes",
+ "democratising": "democratizing",
+ "demonise": "demonize",
+ "demonised": "demonized",
+ "demonises": "demonizes",
+ "demonising": "demonizing",
+ "demoralisation": "demoralization",
+ "demoralise": "demoralize",
+ "demoralised": "demoralized",
+ "demoralises": "demoralizes",
+ "demoralising": "demoralizing",
+ "denationalisation": "denationalization",
+ "denationalise": "denationalize",
+ "denationalised": "denationalized",
+ "denationalises": "denationalizes",
+ "denationalising": "denationalizing",
+ "deodorise": "deodorize",
+ "deodorised": "deodorized",
+ "deodorises": "deodorizes",
+ "deodorising": "deodorizing",
+ "depersonalise": "depersonalize",
+ "depersonalised": "depersonalized",
+ "depersonalises": "depersonalizes",
+ "depersonalising": "depersonalizing",
+ "deputise": "deputize",
+ "deputised": "deputized",
+ "deputises": "deputizes",
+ "deputising": "deputizing",
+ "desensitisation": "desensitization",
+ "desensitise": "desensitize",
+ "desensitised": "desensitized",
+ "desensitises": "desensitizes",
+ "desensitising": "desensitizing",
+ "destabilisation": "destabilization",
+ "destabilise": "destabilize",
+ "destabilised": "destabilized",
+ "destabilises": "destabilizes",
+ "destabilising": "destabilizing",
+ "dialled": "dialed",
+ "dialling": "dialing",
+ "dialogue": "dialog",
+ "dialogues": "dialogs",
+ "diarrhoea": "diarrhea",
+ "digitise": "digitize",
+ "digitised": "digitized",
+ "digitises": "digitizes",
+ "digitising": "digitizing",
+ "disc": "disk",
+ "discolour": "discolor",
+ "discoloured": "discolored",
+ "discolouring": "discoloring",
+ "discolours": "discolors",
+ "discs": "disks",
+ "disembowelled": "disemboweled",
+ "disembowelling": "disemboweling",
+ "disfavour": "disfavor",
+ "dishevelled": "disheveled",
+ "dishonour": "dishonor",
+ "dishonourable": "dishonorable",
+ "dishonourably": "dishonorably",
+ "dishonoured": "dishonored",
+ "dishonouring": "dishonoring",
+ "dishonours": "dishonors",
+ "disorganisation": "disorganization",
+ "disorganised": "disorganized",
+ "distil": "distill",
+ "distils": "distills",
+ "dramatisation": "dramatization",
+ "dramatisations": "dramatizations",
+ "dramatise": "dramatize",
+ "dramatised": "dramatized",
+ "dramatises": "dramatizes",
+ "dramatising": "dramatizing",
+ "draught": "draft",
+ "draughtboard": "draftboard",
+ "draughtboards": "draftboards",
+ "draughtier": "draftier",
+ "draughtiest": "draftiest",
+ "draughts": "drafts",
+ "draughtsman": "draftsman",
+ "draughtsmanship": "draftsmanship",
+ "draughtsmen": "draftsmen",
+ "draughtswoman": "draftswoman",
+ "draughtswomen": "draftswomen",
+ "draughty": "drafty",
+ "drivelled": "driveled",
+ "drivelling": "driveling",
+ "duelled": "dueled",
+ "duelling": "dueling",
+ "economise": "economize",
+ "economised": "economized",
+ "economises": "economizes",
+ "economising": "economizing",
+ "edoema": "edema",
+ "editorialise": "editorialize",
+ "editorialised": "editorialized",
+ "editorialises": "editorializes",
+ "editorialising": "editorializing",
+ "empathise": "empathize",
+ "empathised": "empathized",
+ "empathises": "empathizes",
+ "empathising": "empathizing",
+ "emphasise": "emphasize",
+ "emphasised": "emphasized",
+ "emphasises": "emphasizes",
+ "emphasising": "emphasizing",
+ "enamelled": "enameled",
+ "enamelling": "enameling",
+ "enamoured": "enamored",
+ "encyclopaedia": "encyclopedia",
+ "encyclopaedias": "encyclopedias",
+ "encyclopaedic": "encyclopedic",
+ "endeavour": "endeavor",
+ "endeavoured": "endeavored",
+ "endeavouring": "endeavoring",
+ "endeavours": "endeavors",
+ "energise": "energize",
+ "energised": "energized",
+ "energises": "energizes",
+ "energising": "energizing",
+ "enrol": "enroll",
+ "enrols": "enrolls",
+ "enthral": "enthrall",
+ "enthrals": "enthralls",
+ "epaulette": "epaulet",
+ "epaulettes": "epaulets",
+ "epicentre": "epicenter",
+ "epicentres": "epicenters",
+ "epilogue": "epilog",
+ "epilogues": "epilogs",
+ "epitomise": "epitomize",
+ "epitomised": "epitomized",
+ "epitomises": "epitomizes",
+ "epitomising": "epitomizing",
+ "equalisation": "equalization",
+ "equalise": "equalize",
+ "equalised": "equalized",
+ "equaliser": "equalizer",
+ "equalisers": "equalizers",
+ "equalises": "equalizes",
+ "equalising": "equalizing",
+ "eulogise": "eulogize",
+ "eulogised": "eulogized",
+ "eulogises": "eulogizes",
+ "eulogising": "eulogizing",
+ "evangelise": "evangelize",
+ "evangelised": "evangelized",
+ "evangelises": "evangelizes",
+ "evangelising": "evangelizing",
+ "exorcise": "exorcize",
+ "exorcised": "exorcized",
+ "exorcises": "exorcizes",
+ "exorcising": "exorcizing",
+ "extemporisation": "extemporization",
+ "extemporise": "extemporize",
+ "extemporised": "extemporized",
+ "extemporises": "extemporizes",
+ "extemporising": "extemporizing",
+ "externalisation": "externalization",
+ "externalisations": "externalizations",
+ "externalise": "externalize",
+ "externalised": "externalized",
+ "externalises": "externalizes",
+ "externalising": "externalizing",
+ "factorise": "factorize",
+ "factorised": "factorized",
+ "factorises": "factorizes",
+ "factorising": "factorizing",
+ "faecal": "fecal",
+ "faeces": "feces",
+ "familiarisation": "familiarization",
+ "familiarise": "familiarize",
+ "familiarised": "familiarized",
+ "familiarises": "familiarizes",
+ "familiarising": "familiarizing",
+ "fantasise": "fantasize",
+ "fantasised": "fantasized",
+ "fantasises": "fantasizes",
+ "fantasising": "fantasizing",
+ "favour": "favor",
+ "favourable": "favorable",
+ "favourably": "favorably",
+ "favoured": "favored",
+ "favouring": "favoring",
+ "favourite": "favorite",
+ "favourites": "favorites",
+ "favouritism": "favoritism",
+ "favours": "favors",
+ "feminise": "feminize",
+ "feminised": "feminized",
+ "feminises": "feminizes",
+ "feminising": "feminizing",
+ "fertilisation": "fertilization",
+ "fertilise": "fertilize",
+ "fertilised": "fertilized",
+ "fertiliser": "fertilizer",
+ "fertilisers": "fertilizers",
+ "fertilises": "fertilizes",
+ "fertilising": "fertilizing",
+ "fervour": "fervor",
+ "fibre": "fiber",
+ "fibreglass": "fiberglass",
+ "fibres": "fibers",
+ "fictionalisation": "fictionalization",
+ "fictionalisations": "fictionalizations",
+ "fictionalise": "fictionalize",
+ "fictionalised": "fictionalized",
+ "fictionalises": "fictionalizes",
+ "fictionalising": "fictionalizing",
+ "fillet": "filet",
+ "filleted": "fileted",
+ "filleting": "fileting",
+ "fillets": "filets",
+ "finalisation": "finalization",
+ "finalise": "finalize",
+ "finalised": "finalized",
+ "finalises": "finalizes",
+ "finalising": "finalizing",
+ "flautist": "flutist",
+ "flautists": "flutists",
+ "flavour": "flavor",
+ "flavoured": "flavored",
+ "flavouring": "flavoring",
+ "flavourings": "flavorings",
+ "flavourless": "flavorless",
+ "flavours": "flavors",
+ "flavoursome": "flavorsome",
+ "flyer / flier": "flier / flyer",
+ "foetal": "fetal",
+ "foetid": "fetid",
+ "foetus": "fetus",
+ "foetuses": "fetuses",
+ "formalisation": "formalization",
+ "formalise": "formalize",
+ "formalised": "formalized",
+ "formalises": "formalizes",
+ "formalising": "formalizing",
+ "fossilisation": "fossilization",
+ "fossilise": "fossilize",
+ "fossilised": "fossilized",
+ "fossilises": "fossilizes",
+ "fossilising": "fossilizing",
+ "fraternisation": "fraternization",
+ "fraternise": "fraternize",
+ "fraternised": "fraternized",
+ "fraternises": "fraternizes",
+ "fraternising": "fraternizing",
+ "fulfil": "fulfill",
+ "fulfilment": "fulfillment",
+ "fulfils": "fulfills",
+ "funnelled": "funneled",
+ "funnelling": "funneling",
+ "galvanise": "galvanize",
+ "galvanised": "galvanized",
+ "galvanises": "galvanizes",
+ "galvanising": "galvanizing",
+ "gambolled": "gamboled",
+ "gambolling": "gamboling",
+ "gaol": "jail",
+ "gaolbird": "jailbird",
+ "gaolbirds": "jailbirds",
+ "gaolbreak": "jailbreak",
+ "gaolbreaks": "jailbreaks",
+ "gaoled": "jailed",
+ "gaoler": "jailer",
+ "gaolers": "jailers",
+ "gaoling": "jailing",
+ "gaols": "jails",
+ "gasses": "gases",
+ "gage": "gauge",
+ "gaged": "gauged",
+ "gages": "gauges",
+ "gaging": "gauging",
+ "generalisation": "generalization",
+ "generalisations": "generalizations",
+ "generalise": "generalize",
+ "generalised": "generalized",
+ "generalises": "generalizes",
+ "generalising": "generalizing",
+ "ghettoise": "ghettoize",
+ "ghettoised": "ghettoized",
+ "ghettoises": "ghettoizes",
+ "ghettoising": "ghettoizing",
+ "gipsies": "gypsies",
+ "glamorise": "glamorize",
+ "glamorised": "glamorized",
+ "glamorises": "glamorizes",
+ "glamorising": "glamorizing",
+ "glamor": "glamour",
+ "globalisation": "globalization",
+ "globalise": "globalize",
+ "globalised": "globalized",
+ "globalises": "globalizes",
+ "globalising": "globalizing",
+ "glueing": "gluing",
+ "goitre": "goiter",
+ "goitres": "goiters",
+ "gonorrhoea": "gonorrhea",
+ "gramme": "gram",
+ "grammes": "grams",
+ "gravelled": "graveled",
+ "grey": "gray",
+ "greyed": "grayed",
+ "greying": "graying",
+ "greyish": "grayish",
+ "greyness": "grayness",
+ "greys": "grays",
+ "grovelled": "groveled",
+ "grovelling": "groveling",
+ "groyne": "groin",
+ "groynes": "groins",
+ "gruelling": "grueling",
+ "gruellingly": "gruelingly",
+ "gryphon": "griffin",
+ "gryphons": "griffins",
+ "gynaecological": "gynecological",
+ "gynaecologist": "gynecologist",
+ "gynaecologists": "gynecologists",
+ "gynaecology": "gynecology",
+ "haematological": "hematological",
+ "haematologist": "hematologist",
+ "haematologists": "hematologists",
+ "haematology": "hematology",
+ "haemoglobin": "hemoglobin",
+ "haemophilia": "hemophilia",
+ "haemophiliac": "hemophiliac",
+ "haemophiliacs": "hemophiliacs",
+ "haemorrhage": "hemorrhage",
+ "haemorrhaged": "hemorrhaged",
+ "haemorrhages": "hemorrhages",
+ "haemorrhaging": "hemorrhaging",
+ "haemorrhoids": "hemorrhoids",
+ "harbour": "harbor",
+ "harboured": "harbored",
+ "harbouring": "harboring",
+ "harbours": "harbors",
+ "harmonisation": "harmonization",
+ "harmonise": "harmonize",
+ "harmonised": "harmonized",
+ "harmonises": "harmonizes",
+ "harmonising": "harmonizing",
+ "homoeopath": "homeopath",
+ "homoeopathic": "homeopathic",
+ "homoeopaths": "homeopaths",
+ "homoeopathy": "homeopathy",
+ "homogenise": "homogenize",
+ "homogenised": "homogenized",
+ "homogenises": "homogenizes",
+ "homogenising": "homogenizing",
+ "honour": "honor",
+ "honourable": "honorable",
+ "honourably": "honorably",
+ "honoured": "honored",
+ "honouring": "honoring",
+ "honours": "honors",
+ "hospitalisation": "hospitalization",
+ "hospitalise": "hospitalize",
+ "hospitalised": "hospitalized",
+ "hospitalises": "hospitalizes",
+ "hospitalising": "hospitalizing",
+ "humanise": "humanize",
+ "humanised": "humanized",
+ "humanises": "humanizes",
+ "humanising": "humanizing",
+ "humour": "humor",
+ "humoured": "humored",
+ "humouring": "humoring",
+ "humourless": "humorless",
+ "humours": "humors",
+ "hybridise": "hybridize",
+ "hybridised": "hybridized",
+ "hybridises": "hybridizes",
+ "hybridising": "hybridizing",
+ "hypnotise": "hypnotize",
+ "hypnotised": "hypnotized",
+ "hypnotises": "hypnotizes",
+ "hypnotising": "hypnotizing",
+ "hypothesise": "hypothesize",
+ "hypothesised": "hypothesized",
+ "hypothesises": "hypothesizes",
+ "hypothesising": "hypothesizing",
+ "idealisation": "idealization",
+ "idealise": "idealize",
+ "idealised": "idealized",
+ "idealises": "idealizes",
+ "idealising": "idealizing",
+ "idolise": "idolize",
+ "idolised": "idolized",
+ "idolises": "idolizes",
+ "idolising": "idolizing",
+ "immobilisation": "immobilization",
+ "immobilise": "immobilize",
+ "immobilised": "immobilized",
+ "immobiliser": "immobilizer",
+ "immobilisers": "immobilizers",
+ "immobilises": "immobilizes",
+ "immobilising": "immobilizing",
+ "immortalise": "immortalize",
+ "immortalised": "immortalized",
+ "immortalises": "immortalizes",
+ "immortalising": "immortalizing",
+ "immunisation": "immunization",
+ "immunise": "immunize",
+ "immunised": "immunized",
+ "immunises": "immunizes",
+ "immunising": "immunizing",
+ "impanelled": "impaneled",
+ "impanelling": "impaneling",
+ "imperilled": "imperiled",
+ "imperilling": "imperiling",
+ "individualise": "individualize",
+ "individualised": "individualized",
+ "individualises": "individualizes",
+ "individualising": "individualizing",
+ "industrialise": "industrialize",
+ "industrialised": "industrialized",
+ "industrialises": "industrializes",
+ "industrialising": "industrializing",
+ "inflexion": "inflection",
+ "inflexions": "inflections",
+ "initialise": "initialize",
+ "initialised": "initialized",
+ "initialises": "initializes",
+ "initialising": "initializing",
+ "initialled": "initialed",
+ "initialling": "initialing",
+ "instal": "install",
+ "instalment": "installment",
+ "instalments": "installments",
+ "instals": "installs",
+ "instil": "instill",
+ "instils": "instills",
+ "institutionalisation": "institutionalization",
+ "institutionalise": "institutionalize",
+ "institutionalised": "institutionalized",
+ "institutionalises": "institutionalizes",
+ "institutionalising": "institutionalizing",
+ "intellectualise": "intellectualize",
+ "intellectualised": "intellectualized",
+ "intellectualises": "intellectualizes",
+ "intellectualising": "intellectualizing",
+ "internalisation": "internalization",
+ "internalise": "internalize",
+ "internalised": "internalized",
+ "internalises": "internalizes",
+ "internalising": "internalizing",
+ "internationalisation": "internationalization",
+ "internationalise": "internationalize",
+ "internationalised": "internationalized",
+ "internationalises": "internationalizes",
+ "internationalising": "internationalizing",
+ "ionisation": "ionization",
+ "ionise": "ionize",
+ "ionised": "ionized",
+ "ioniser": "ionizer",
+ "ionisers": "ionizers",
+ "ionises": "ionizes",
+ "ionising": "ionizing",
+ "italicise": "italicize",
+ "italicised": "italicized",
+ "italicises": "italicizes",
+ "italicising": "italicizing",
+ "itemise": "itemize",
+ "itemised": "itemized",
+ "itemises": "itemizes",
+ "itemising": "itemizing",
+ "jeopardise": "jeopardize",
+ "jeopardised": "jeopardized",
+ "jeopardises": "jeopardizes",
+ "jeopardising": "jeopardizing",
+ "jewelled": "jeweled",
+ "jeweller": "jeweler",
+ "jewellers": "jewelers",
+ "jewellery": "jewelry",
+ "judgement": "judgment",
+ "kilogramme": "kilogram",
+ "kilogrammes": "kilograms",
+ "kilometre": "kilometer",
+ "kilometres": "kilometers",
+ "labelled": "labeled",
+ "labelling": "labeling",
+ "labour": "labor",
+ "laboured": "labored",
+ "labourer": "laborer",
+ "labourers": "laborers",
+ "labouring": "laboring",
+ "labours": "labors",
+ "lacklustre": "lackluster",
+ "legalisation": "legalization",
+ "legalise": "legalize",
+ "legalised": "legalized",
+ "legalises": "legalizes",
+ "legalising": "legalizing",
+ "legitimise": "legitimize",
+ "legitimised": "legitimized",
+ "legitimises": "legitimizes",
+ "legitimising": "legitimizing",
+ "leukaemia": "leukemia",
+ "levelled": "leveled",
+ "leveller": "leveler",
+ "levellers": "levelers",
+ "levelling": "leveling",
+ "libelled": "libeled",
+ "libelling": "libeling",
+ "libellous": "libelous",
+ "liberalisation": "liberalization",
+ "liberalise": "liberalize",
+ "liberalised": "liberalized",
+ "liberalises": "liberalizes",
+ "liberalising": "liberalizing",
+ "licence": "license",
+ "licenced": "licensed",
+ "licences": "licenses",
+ "licencing": "licensing",
+ "likeable": "likable",
+ "lionisation": "lionization",
+ "lionise": "lionize",
+ "lionised": "lionized",
+ "lionises": "lionizes",
+ "lionising": "lionizing",
+ "liquidise": "liquidize",
+ "liquidised": "liquidized",
+ "liquidiser": "liquidizer",
+ "liquidisers": "liquidizers",
+ "liquidises": "liquidizes",
+ "liquidising": "liquidizing",
+ "litre": "liter",
+ "litres": "liters",
+ "localise": "localize",
+ "localised": "localized",
+ "localises": "localizes",
+ "localising": "localizing",
+ "louvre": "louver",
+ "louvred": "louvered",
+ "louvres": "louvers",
+ "lustre": "luster",
+ "magnetise": "magnetize",
+ "magnetised": "magnetized",
+ "magnetises": "magnetizes",
+ "magnetising": "magnetizing",
+ "manoeuvrability": "maneuverability",
+ "manoeuvrable": "maneuverable",
+ "manoeuvre": "maneuver",
+ "manoeuvred": "maneuvered",
+ "manoeuvres": "maneuvers",
+ "manoeuvring": "maneuvering",
+ "manoeuvrings": "maneuverings",
+ "marginalisation": "marginalization",
+ "marginalise": "marginalize",
+ "marginalised": "marginalized",
+ "marginalises": "marginalizes",
+ "marginalising": "marginalizing",
+ "marshalled": "marshaled",
+ "marshalling": "marshaling",
+ "marvelled": "marveled",
+ "marvelling": "marveling",
+ "marvellous": "marvelous",
+ "marvellously": "marvelously",
+ "materialisation": "materialization",
+ "materialise": "materialize",
+ "materialised": "materialized",
+ "materialises": "materializes",
+ "materialising": "materializing",
+ "maximisation": "maximization",
+ "maximise": "maximize",
+ "maximised": "maximized",
+ "maximises": "maximizes",
+ "maximising": "maximizing",
+ "meagre": "meager",
+ "mechanisation": "mechanization",
+ "mechanise": "mechanize",
+ "mechanised": "mechanized",
+ "mechanises": "mechanizes",
+ "mechanising": "mechanizing",
+ "mediaeval": "medieval",
+ "memorialise": "memorialize",
+ "memorialised": "memorialized",
+ "memorialises": "memorializes",
+ "memorialising": "memorializing",
+ "memorise": "memorize",
+ "memorised": "memorized",
+ "memorises": "memorizes",
+ "memorising": "memorizing",
+ "mesmerise": "mesmerize",
+ "mesmerised": "mesmerized",
+ "mesmerises": "mesmerizes",
+ "mesmerising": "mesmerizing",
+ "metabolise": "metabolize",
+ "metabolised": "metabolized",
+ "metabolises": "metabolizes",
+ "metabolising": "metabolizing",
+ "metre": "meter",
+ "metres": "meters",
+ "micrometre": "micrometer",
+ "micrometres": "micrometers",
+ "militarise": "militarize",
+ "militarised": "militarized",
+ "militarises": "militarizes",
+ "militarising": "militarizing",
+ "milligramme": "milligram",
+ "milligrammes": "milligrams",
+ "millilitre": "milliliter",
+ "millilitres": "milliliters",
+ "millimetre": "millimeter",
+ "millimetres": "millimeters",
+ "miniaturisation": "miniaturization",
+ "miniaturise": "miniaturize",
+ "miniaturised": "miniaturized",
+ "miniaturises": "miniaturizes",
+ "miniaturising": "miniaturizing",
+ "minibusses": "minibuses",
+ "minimise": "minimize",
+ "minimised": "minimized",
+ "minimises": "minimizes",
+ "minimising": "minimizing",
+ "misbehaviour": "misbehavior",
+ "misdemeanour": "misdemeanor",
+ "misdemeanours": "misdemeanors",
+ "misspelt": "misspelled",
+ "mitre": "miter",
+ "mitres": "miters",
+ "mobilisation": "mobilization",
+ "mobilise": "mobilize",
+ "mobilised": "mobilized",
+ "mobilises": "mobilizes",
+ "mobilising": "mobilizing",
+ "modelled": "modeled",
+ "modeller": "modeler",
+ "modellers": "modelers",
+ "modelling": "modeling",
+ "modernise": "modernize",
+ "modernised": "modernized",
+ "modernises": "modernizes",
+ "modernising": "modernizing",
+ "moisturise": "moisturize",
+ "moisturised": "moisturized",
+ "moisturiser": "moisturizer",
+ "moisturisers": "moisturizers",
+ "moisturises": "moisturizes",
+ "moisturising": "moisturizing",
+ "monologue": "monolog",
+ "monologues": "monologs",
+ "monopolisation": "monopolization",
+ "monopolise": "monopolize",
+ "monopolised": "monopolized",
+ "monopolises": "monopolizes",
+ "monopolising": "monopolizing",
+ "moralise": "moralize",
+ "moralised": "moralized",
+ "moralises": "moralizes",
+ "moralising": "moralizing",
+ "motorised": "motorized",
+ "mould": "mold",
+ "moulded": "molded",
+ "moulder": "molder",
+ "mouldered": "moldered",
+ "mouldering": "moldering",
+ "moulders": "molders",
+ "mouldier": "moldier",
+ "mouldiest": "moldiest",
+ "moulding": "molding",
+ "mouldings": "moldings",
+ "moulds": "molds",
+ "mouldy": "moldy",
+ "moult": "molt",
+ "moulted": "molted",
+ "moulting": "molting",
+ "moults": "molts",
+ "moustache": "mustache",
+ "moustached": "mustached",
+ "moustaches": "mustaches",
+ "moustachioed": "mustachioed",
+ "multicoloured": "multicolored",
+ "nationalisation": "nationalization",
+ "nationalisations": "nationalizations",
+ "nationalise": "nationalize",
+ "nationalised": "nationalized",
+ "nationalises": "nationalizes",
+ "nationalising": "nationalizing",
+ "naturalisation": "naturalization",
+ "naturalise": "naturalize",
+ "naturalised": "naturalized",
+ "naturalises": "naturalizes",
+ "naturalising": "naturalizing",
+ "neighbour": "neighbor",
+ "neighbourhood": "neighborhood",
+ "neighbourhoods": "neighborhoods",
+ "neighbouring": "neighboring",
+ "neighbourliness": "neighborliness",
+ "neighbourly": "neighborly",
+ "neighbours": "neighbors",
+ "neutralisation": "neutralization",
+ "neutralise": "neutralize",
+ "neutralised": "neutralized",
+ "neutralises": "neutralizes",
+ "neutralising": "neutralizing",
+ "normalisation": "normalization",
+ "normalise": "normalize",
+ "normalised": "normalized",
+ "normalises": "normalizes",
+ "normalising": "normalizing",
+ "odour": "odor",
+ "odourless": "odorless",
+ "odours": "odors",
+ "oesophagus": "esophagus",
+ "oesophaguses": "esophaguses",
+ "oestrogen": "estrogen",
+ "offence": "offense",
+ "offences": "offenses",
+ "omelette": "omelet",
+ "omelettes": "omelets",
+ "optimise": "optimize",
+ "optimised": "optimized",
+ "optimises": "optimizes",
+ "optimising": "optimizing",
+ "organisation": "organization",
+ "organisational": "organizational",
+ "organisations": "organizations",
+ "organise": "organize",
+ "organised": "organized",
+ "organiser": "organizer",
+ "organisers": "organizers",
+ "organises": "organizes",
+ "organising": "organizing",
+ "orthopaedic": "orthopedic",
+ "orthopaedics": "orthopedics",
+ "ostracise": "ostracize",
+ "ostracised": "ostracized",
+ "ostracises": "ostracizes",
+ "ostracising": "ostracizing",
+ "outmanoeuvre": "outmaneuver",
+ "outmanoeuvred": "outmaneuvered",
+ "outmanoeuvres": "outmaneuvers",
+ "outmanoeuvring": "outmaneuvering",
+ "overemphasise": "overemphasize",
+ "overemphasised": "overemphasized",
+ "overemphasises": "overemphasizes",
+ "overemphasising": "overemphasizing",
+ "oxidisation": "oxidization",
+ "oxidise": "oxidize",
+ "oxidised": "oxidized",
+ "oxidises": "oxidizes",
+ "oxidising": "oxidizing",
+ "paederast": "pederast",
+ "paederasts": "pederasts",
+ "paediatric": "pediatric",
+ "paediatrician": "pediatrician",
+ "paediatricians": "pediatricians",
+ "paediatrics": "pediatrics",
+ "paedophile": "pedophile",
+ "paedophiles": "pedophiles",
+ "paedophilia": "pedophilia",
+ "palaeolithic": "paleolithic",
+ "palaeontologist": "paleontologist",
+ "palaeontologists": "paleontologists",
+ "palaeontology": "paleontology",
+ "panelled": "paneled",
+ "panelling": "paneling",
+ "panellist": "panelist",
+ "panellists": "panelists",
+ "paralyse": "paralyze",
+ "paralysed": "paralyzed",
+ "paralyses": "paralyzes",
+ "paralysing": "paralyzing",
+ "parcelled": "parceled",
+ "parcelling": "parceling",
+ "parlour": "parlor",
+ "parlours": "parlors",
+ "particularise": "particularize",
+ "particularised": "particularized",
+ "particularises": "particularizes",
+ "particularising": "particularizing",
+ "passivisation": "passivization",
+ "passivise": "passivize",
+ "passivised": "passivized",
+ "passivises": "passivizes",
+ "passivising": "passivizing",
+ "pasteurisation": "pasteurization",
+ "pasteurise": "pasteurize",
+ "pasteurised": "pasteurized",
+ "pasteurises": "pasteurizes",
+ "pasteurising": "pasteurizing",
+ "patronise": "patronize",
+ "patronised": "patronized",
+ "patronises": "patronizes",
+ "patronising": "patronizing",
+ "patronisingly": "patronizingly",
+ "pedalled": "pedaled",
+ "pedalling": "pedaling",
+ "pedestrianisation": "pedestrianization",
+ "pedestrianise": "pedestrianize",
+ "pedestrianised": "pedestrianized",
+ "pedestrianises": "pedestrianizes",
+ "pedestrianising": "pedestrianizing",
+ "penalise": "penalize",
+ "penalised": "penalized",
+ "penalises": "penalizes",
+ "penalising": "penalizing",
+ "pencilled": "penciled",
+ "pencilling": "penciling",
+ "personalise": "personalize",
+ "personalised": "personalized",
+ "personalises": "personalizes",
+ "personalising": "personalizing",
+ "pharmacopoeia": "pharmacopeia",
+ "pharmacopoeias": "pharmacopeias",
+ "philosophise": "philosophize",
+ "philosophised": "philosophized",
+ "philosophises": "philosophizes",
+ "philosophising": "philosophizing",
+ "philtre": "filter",
+ "philtres": "filters",
+ "phoney": "phony",
+ "plagiarise": "plagiarize",
+ "plagiarised": "plagiarized",
+ "plagiarises": "plagiarizes",
+ "plagiarising": "plagiarizing",
+ "plough": "plow",
+ "ploughed": "plowed",
+ "ploughing": "plowing",
+ "ploughman": "plowman",
+ "ploughmen": "plowmen",
+ "ploughs": "plows",
+ "ploughshare": "plowshare",
+ "ploughshares": "plowshares",
+ "polarisation": "polarization",
+ "polarise": "polarize",
+ "polarised": "polarized",
+ "polarises": "polarizes",
+ "polarising": "polarizing",
+ "politicisation": "politicization",
+ "politicise": "politicize",
+ "politicised": "politicized",
+ "politicises": "politicizes",
+ "politicising": "politicizing",
+ "popularisation": "popularization",
+ "popularise": "popularize",
+ "popularised": "popularized",
+ "popularises": "popularizes",
+ "popularising": "popularizing",
+ "pouffe": "pouf",
+ "pouffes": "poufs",
+ "practise": "practice",
+ "practised": "practiced",
+ "practises": "practices",
+ "practising": "practicing",
+ "praesidium": "presidium",
+ "praesidiums": "presidiums",
+ "pressurisation": "pressurization",
+ "pressurise": "pressurize",
+ "pressurised": "pressurized",
+ "pressurises": "pressurizes",
+ "pressurising": "pressurizing",
+ "pretence": "pretense",
+ "pretences": "pretenses",
+ "primaeval": "primeval",
+ "prioritisation": "prioritization",
+ "prioritise": "prioritize",
+ "prioritised": "prioritized",
+ "prioritises": "prioritizes",
+ "prioritising": "prioritizing",
+ "privatisation": "privatization",
+ "privatisations": "privatizations",
+ "privatise": "privatize",
+ "privatised": "privatized",
+ "privatises": "privatizes",
+ "privatising": "privatizing",
+ "professionalisation": "professionalization",
+ "professionalise": "professionalize",
+ "professionalised": "professionalized",
+ "professionalises": "professionalizes",
+ "professionalising": "professionalizing",
+ "programme": "program",
+ "programmes": "programs",
+ "prologue": "prolog",
+ "prologues": "prologs",
+ "propagandise": "propagandize",
+ "propagandised": "propagandized",
+ "propagandises": "propagandizes",
+ "propagandising": "propagandizing",
+ "proselytise": "proselytize",
+ "proselytised": "proselytized",
+ "proselytiser": "proselytizer",
+ "proselytisers": "proselytizers",
+ "proselytises": "proselytizes",
+ "proselytising": "proselytizing",
+ "psychoanalyse": "psychoanalyze",
+ "psychoanalysed": "psychoanalyzed",
+ "psychoanalyses": "psychoanalyzes",
+ "psychoanalysing": "psychoanalyzing",
+ "publicise": "publicize",
+ "publicised": "publicized",
+ "publicises": "publicizes",
+ "publicising": "publicizing",
+ "pulverisation": "pulverization",
+ "pulverise": "pulverize",
+ "pulverised": "pulverized",
+ "pulverises": "pulverizes",
+ "pulverising": "pulverizing",
+ "pummelled": "pummel",
+ "pummelling": "pummeled",
+ "pyjama": "pajama",
+ "pyjamas": "pajamas",
+ "pzazz": "pizzazz",
+ "quarrelled": "quarreled",
+ "quarrelling": "quarreling",
+ "radicalise": "radicalize",
+ "radicalised": "radicalized",
+ "radicalises": "radicalizes",
+ "radicalising": "radicalizing",
+ "rancour": "rancor",
+ "randomise": "randomize",
+ "randomised": "randomized",
+ "randomises": "randomizes",
+ "randomising": "randomizing",
+ "rationalisation": "rationalization",
+ "rationalisations": "rationalizations",
+ "rationalise": "rationalize",
+ "rationalised": "rationalized",
+ "rationalises": "rationalizes",
+ "rationalising": "rationalizing",
+ "ravelled": "raveled",
+ "ravelling": "raveling",
+ "realisable": "realizable",
+ "realisation": "realization",
+ "realisations": "realizations",
+ "realise": "realize",
+ "realised": "realized",
+ "realises": "realizes",
+ "realising": "realizing",
+ "recognisable": "recognizable",
+ "recognisably": "recognizably",
+ "recognisance": "recognizance",
+ "recognise": "recognize",
+ "recognised": "recognized",
+ "recognises": "recognizes",
+ "recognising": "recognizing",
+ "reconnoitre": "reconnoiter",
+ "reconnoitred": "reconnoitered",
+ "reconnoitres": "reconnoiters",
+ "reconnoitring": "reconnoitering",
+ "refuelled": "refueled",
+ "refuelling": "refueling",
+ "regularisation": "regularization",
+ "regularise": "regularize",
+ "regularised": "regularized",
+ "regularises": "regularizes",
+ "regularising": "regularizing",
+ "remodelled": "remodeled",
+ "remodelling": "remodeling",
+ "remould": "remold",
+ "remoulded": "remolded",
+ "remoulding": "remolding",
+ "remoulds": "remolds",
+ "reorganisation": "reorganization",
+ "reorganisations": "reorganizations",
+ "reorganise": "reorganize",
+ "reorganised": "reorganized",
+ "reorganises": "reorganizes",
+ "reorganising": "reorganizing",
+ "revelled": "reveled",
+ "reveller": "reveler",
+ "revellers": "revelers",
+ "revelling": "reveling",
+ "revitalise": "revitalize",
+ "revitalised": "revitalized",
+ "revitalises": "revitalizes",
+ "revitalising": "revitalizing",
+ "revolutionise": "revolutionize",
+ "revolutionised": "revolutionized",
+ "revolutionises": "revolutionizes",
+ "revolutionising": "revolutionizing",
+ "rhapsodise": "rhapsodize",
+ "rhapsodised": "rhapsodized",
+ "rhapsodises": "rhapsodizes",
+ "rhapsodising": "rhapsodizing",
+ "rigour": "rigor",
+ "rigours": "rigors",
+ "ritualised": "ritualized",
+ "rivalled": "rivaled",
+ "rivalling": "rivaling",
+ "romanticise": "romanticize",
+ "romanticised": "romanticized",
+ "romanticises": "romanticizes",
+ "romanticising": "romanticizing",
+ "rumour": "rumor",
+ "rumoured": "rumored",
+ "rumours": "rumors",
+ "sabre": "saber",
+ "sabres": "sabers",
+ "saltpetre": "saltpeter",
+ "sanitise": "sanitize",
+ "sanitised": "sanitized",
+ "sanitises": "sanitizes",
+ "sanitising": "sanitizing",
+ "satirise": "satirize",
+ "satirised": "satirized",
+ "satirises": "satirizes",
+ "satirising": "satirizing",
+ "saviour": "savior",
+ "saviours": "saviors",
+ "savour": "savor",
+ "savoured": "savored",
+ "savouries": "savories",
+ "savouring": "savoring",
+ "savours": "savors",
+ "savoury": "savory",
+ "scandalise": "scandalize",
+ "scandalised": "scandalized",
+ "scandalises": "scandalizes",
+ "scandalising": "scandalizing",
+ "sceptic": "skeptic",
+ "sceptical": "skeptical",
+ "sceptically": "skeptically",
+ "scepticism": "skepticism",
+ "sceptics": "skeptics",
+ "sceptre": "scepter",
+ "sceptres": "scepters",
+ "scrutinise": "scrutinize",
+ "scrutinised": "scrutinized",
+ "scrutinises": "scrutinizes",
+ "scrutinising": "scrutinizing",
+ "secularisation": "secularization",
+ "secularise": "secularize",
+ "secularised": "secularized",
+ "secularises": "secularizes",
+ "secularising": "secularizing",
+ "sensationalise": "sensationalize",
+ "sensationalised": "sensationalized",
+ "sensationalises": "sensationalizes",
+ "sensationalising": "sensationalizing",
+ "sensitise": "sensitize",
+ "sensitised": "sensitized",
+ "sensitises": "sensitizes",
+ "sensitising": "sensitizing",
+ "sentimentalise": "sentimentalize",
+ "sentimentalised": "sentimentalized",
+ "sentimentalises": "sentimentalizes",
+ "sentimentalising": "sentimentalizing",
+ "sepulchre": "sepulcher",
+ "sepulchres": "sepulchers",
+ "serialisation": "serialization",
+ "serialisations": "serializations",
+ "serialise": "serialize",
+ "serialised": "serialized",
+ "serialises": "serializes",
+ "serialising": "serializing",
+ "sermonise": "sermonize",
+ "sermonised": "sermonized",
+ "sermonises": "sermonizes",
+ "sermonising": "sermonizing",
+ "sheikh": "sheik",
+ "shovelled": "shoveled",
+ "shovelling": "shoveling",
+ "shrivelled": "shriveled",
+ "shrivelling": "shriveling",
+ "signalise": "signalize",
+ "signalised": "signalized",
+ "signalises": "signalizes",
+ "signalising": "signalizing",
+ "signalled": "signaled",
+ "signalling": "signaling",
+ "smoulder": "smolder",
+ "smouldered": "smoldered",
+ "smouldering": "smoldering",
+ "smoulders": "smolders",
+ "snivelled": "sniveled",
+ "snivelling": "sniveling",
+ "snorkelled": "snorkeled",
+ "snorkelling": "snorkeling",
+ "snowplough": "snowplow",
+ "snowploughs": "snowplow",
+ "socialisation": "socialization",
+ "socialise": "socialize",
+ "socialised": "socialized",
+ "socialises": "socializes",
+ "socialising": "socializing",
+ "sodomise": "sodomize",
+ "sodomised": "sodomized",
+ "sodomises": "sodomizes",
+ "sodomising": "sodomizing",
+ "solemnise": "solemnize",
+ "solemnised": "solemnized",
+ "solemnises": "solemnizes",
+ "solemnising": "solemnizing",
+ "sombre": "somber",
+ "specialisation": "specialization",
+ "specialisations": "specializations",
+ "specialise": "specialize",
+ "specialised": "specialized",
+ "specialises": "specializes",
+ "specialising": "specializing",
+ "spectre": "specter",
+ "spectres": "specters",
+ "spiralled": "spiraled",
+ "spiralling": "spiraling",
+ "splendour": "splendor",
+ "splendours": "splendors",
+ "squirrelled": "squirreled",
+ "squirrelling": "squirreling",
+ "stabilisation": "stabilization",
+ "stabilise": "stabilize",
+ "stabilised": "stabilized",
+ "stabiliser": "stabilizer",
+ "stabilisers": "stabilizers",
+ "stabilises": "stabilizes",
+ "stabilising": "stabilizing",
+ "standardisation": "standardization",
+ "standardise": "standardize",
+ "standardised": "standardized",
+ "standardises": "standardizes",
+ "standardising": "standardizing",
+ "stencilled": "stenciled",
+ "stencilling": "stenciling",
+ "sterilisation": "sterilization",
+ "sterilisations": "sterilizations",
+ "sterilise": "sterilize",
+ "sterilised": "sterilized",
+ "steriliser": "sterilizer",
+ "sterilisers": "sterilizers",
+ "sterilises": "sterilizes",
+ "sterilising": "sterilizing",
+ "stigmatisation": "stigmatization",
+ "stigmatise": "stigmatize",
+ "stigmatised": "stigmatized",
+ "stigmatises": "stigmatizes",
+ "stigmatising": "stigmatizing",
+ "storey": "story",
+ "storeys": "stories",
+ "subsidisation": "subsidization",
+ "subsidise": "subsidize",
+ "subsidised": "subsidized",
+ "subsidiser": "subsidizer",
+ "subsidisers": "subsidizers",
+ "subsidises": "subsidizes",
+ "subsidising": "subsidizing",
+ "succour": "succor",
+ "succoured": "succored",
+ "succouring": "succoring",
+ "succours": "succors",
+ "sulphate": "sulfate",
+ "sulphates": "sulfates",
+ "sulphide": "sulfide",
+ "sulphides": "sulfides",
+ "sulphur": "sulfur",
+ "sulphurous": "sulfurous",
+ "summarise": "summarize",
+ "summarised": "summarized",
+ "summarises": "summarizes",
+ "summarising": "summarizing",
+ "swivelled": "swiveled",
+ "swivelling": "swiveling",
+ "symbolise": "symbolize",
+ "symbolised": "symbolized",
+ "symbolises": "symbolizes",
+ "symbolising": "symbolizing",
+ "sympathise": "sympathize",
+ "sympathised": "sympathized",
+ "sympathiser": "sympathizer",
+ "sympathisers": "sympathizers",
+ "sympathises": "sympathizes",
+ "sympathising": "sympathizing",
+ "synchronisation": "synchronization",
+ "synchronise": "synchronize",
+ "synchronised": "synchronized",
+ "synchronises": "synchronizes",
+ "synchronising": "synchronizing",
+ "synthesise": "synthesize",
+ "synthesised": "synthesized",
+ "synthesiser": "synthesizer",
+ "synthesisers": "synthesizers",
+ "synthesises": "synthesizes",
+ "synthesising": "synthesizing",
+ "syphon": "siphon",
+ "syphoned": "siphoned",
+ "syphoning": "siphoning",
+ "syphons": "siphons",
+ "systematisation": "systematization",
+ "systematise": "systematize",
+ "systematised": "systematized",
+ "systematises": "systematizes",
+ "systematising": "systematizing",
+ "tantalise": "tantalize",
+ "tantalised": "tantalized",
+ "tantalises": "tantalizes",
+ "tantalising": "tantalizing",
+ "tantalisingly": "tantalizingly",
+ "tasselled": "tasseled",
+ "technicolour": "technicolor",
+ "temporise": "temporize",
+ "temporised": "temporized",
+ "temporises": "temporizes",
+ "temporising": "temporizing",
+ "tenderise": "tenderize",
+ "tenderised": "tenderized",
+ "tenderises": "tenderizes",
+ "tenderising": "tenderizing",
+ "terrorise": "terrorize",
+ "terrorised": "terrorized",
+ "terrorises": "terrorizes",
+ "terrorising": "terrorizing",
+ "theatre": "theater",
+ "theatregoer": "theatergoer",
+ "theatregoers": "theatergoers",
+ "theatres": "theaters",
+ "theorise": "theorize",
+ "theorised": "theorized",
+ "theorises": "theorizes",
+ "theorising": "theorizing",
+ "tonne": "ton",
+ "tonnes": "tons",
+ "towelled": "toweled",
+ "towelling": "toweling",
+ "toxaemia": "toxemia",
+ "tranquillise": "tranquilize",
+ "tranquillised": "tranquilized",
+ "tranquilliser": "tranquilizer",
+ "tranquillisers": "tranquilizers",
+ "tranquillises": "tranquilizes",
+ "tranquillising": "tranquilizing",
+ "tranquillity": "tranquility",
+ "tranquillize": "tranquilize",
+ "tranquillized": "tranquilized",
+ "tranquillizer": "tranquilizer",
+ "tranquillizers": "tranquilizers",
+ "tranquillizes": "tranquilizes",
+ "tranquillizing": "tranquilizing",
+ "tranquilly": "tranquility",
+ "transistorised": "transistorized",
+ "traumatise": "traumatize",
+ "traumatised": "traumatized",
+ "traumatises": "traumatizes",
+ "traumatising": "traumatizing",
+ "travelled": "traveled",
+ "traveller": "traveler",
+ "travellers": "travelers",
+ "travelling": "traveling",
+ "travelog": "travelogue",
+ "travelogs": "travelogues",
+ "trialled": "trialed",
+ "trialling": "trialing",
+ "tricolour": "tricolor",
+ "tricolours": "tricolors",
+ "trivialise": "trivialize",
+ "trivialised": "trivialized",
+ "trivialises": "trivializes",
+ "trivialising": "trivializing",
+ "tumour": "tumor",
+ "tumours": "tumors",
+ "tunnelled": "tunneled",
+ "tunnelling": "tunneling",
+ "tyrannise": "tyrannize",
+ "tyrannised": "tyrannized",
+ "tyrannises": "tyrannizes",
+ "tyrannising": "tyrannizing",
+ "tyre": "tire",
+ "tyres": "tires",
+ "unauthorised": "unauthorized",
+ "uncivilised": "uncivilized",
+ "underutilised": "underutilized",
+ "unequalled": "unequaled",
+ "unfavourable": "unfavorable",
+ "unfavourably": "unfavorably",
+ "unionisation": "unionization",
+ "unionise": "unionize",
+ "unionised": "unionized",
+ "unionises": "unionizes",
+ "unionising": "unionizing",
+ "unorganised": "unorganized",
+ "unravelled": "unraveled",
+ "unravelling": "unraveling",
+ "unrecognisable": "unrecognizable",
+ "unrecognised": "unrecognized",
+ "unrivalled": "unrivaled",
+ "unsavoury": "unsavory",
+ "untrammelled": "untrammeled",
+ "urbanisation": "urbanization",
+ "urbanise": "urbanize",
+ "urbanised": "urbanized",
+ "urbanises": "urbanizes",
+ "urbanising": "urbanizing",
+ "utilisable": "utilizable",
+ "utilisation": "utilization",
+ "utilise": "utilize",
+ "utilised": "utilized",
+ "utilises": "utilizes",
+ "utilising": "utilizing",
+ "valour": "valor",
+ "vandalise": "vandalize",
+ "vandalised": "vandalized",
+ "vandalises": "vandalizes",
+ "vandalising": "vandalizing",
+ "vaporisation": "vaporization",
+ "vaporise": "vaporize",
+ "vaporised": "vaporized",
+ "vaporises": "vaporizes",
+ "vaporising": "vaporizing",
+ "vapour": "vapor",
+ "vapours": "vapors",
+ "verbalise": "verbalize",
+ "verbalised": "verbalized",
+ "verbalises": "verbalizes",
+ "verbalising": "verbalizing",
+ "victimisation": "victimization",
+ "victimise": "victimize",
+ "victimised": "victimized",
+ "victimises": "victimizes",
+ "victimising": "victimizing",
+ "videodisc": "videodisk",
+ "videodiscs": "videodisks",
+ "vigour": "vigor",
+ "visualisation": "visualization",
+ "visualisations": "visualizations",
+ "visualise": "visualize",
+ "visualised": "visualized",
+ "visualises": "visualizes",
+ "visualising": "visualizing",
+ "vocalisation": "vocalization",
+ "vocalisations": "vocalizations",
+ "vocalise": "vocalize",
+ "vocalised": "vocalized",
+ "vocalises": "vocalizes",
+ "vocalising": "vocalizing",
+ "vulcanised": "vulcanized",
+ "vulgarisation": "vulgarization",
+ "vulgarise": "vulgarize",
+ "vulgarised": "vulgarized",
+ "vulgarises": "vulgarizes",
+ "vulgarising": "vulgarizing",
+ "waggon": "wagon",
+ "waggons": "wagons",
+ "watercolour": "watercolor",
+ "watercolours": "watercolors",
+ "weaselled": "weaseled",
+ "weaselling": "weaseling",
+ "westernisation": "westernization",
+ "westernise": "westernize",
+ "westernised": "westernized",
+ "westernises": "westernizes",
+ "westernising": "westernizing",
+ "womanise": "womanize",
+ "womanised": "womanized",
+ "womaniser": "womanizer",
+ "womanisers": "womanizers",
+ "womanises": "womanizes",
+ "womanising": "womanizing",
+ "woollen": "woolen",
+ "woollens": "woolens",
+ "woollies": "woolies",
+ "woolly": "wooly",
+ "worshipped": "worshiped",
+ "worshipping": "worshiping",
+ "worshipper": "worshiper",
+ "yodelled": "yodeled",
+ "yodelling": "yodeling",
+ "yoghourt": "yogurt",
+ "yoghourts": "yogurts",
+ "yoghurt": "yogurt",
+ "yoghurts": "yogurts",
+ "mhm": "hmm",
+ "mmm": "hmm"
+}
\ No newline at end of file
diff --git a/lmms_eval/tasks/gigaspeech/whisper_normalizer/english.py b/lmms_eval/tasks/gigaspeech/whisper_normalizer/english.py
new file mode 100644
index 00000000..7102f2ec
--- /dev/null
+++ b/lmms_eval/tasks/gigaspeech/whisper_normalizer/english.py
@@ -0,0 +1,529 @@
+import json
+import os
+import re
+from fractions import Fraction
+from typing import Iterator, List, Match, Optional, Union
+
+from more_itertools import windowed  # TODO: new dependency (more_itertools)
+
+from .basic import remove_symbols_and_diacritics
+
+
+class EnglishNumberNormalizer:
+ """
+ Convert any spelled-out numbers into Arabic numerals, while handling:
+
+ - remove any commas
+ - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
+ - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
+ - spell out `one` and `ones`
+ - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ self.zeros = {"o", "oh", "zero"}
+ self.ones = {
+ name: i
+ for i, name in enumerate(
+ [
+ "one",
+ "two",
+ "three",
+ "four",
+ "five",
+ "six",
+ "seven",
+ "eight",
+ "nine",
+ "ten",
+ "eleven",
+ "twelve",
+ "thirteen",
+ "fourteen",
+ "fifteen",
+ "sixteen",
+ "seventeen",
+ "eighteen",
+ "nineteen",
+ ],
+ start=1,
+ )
+ }
+ self.ones_plural = {"sixes" if name == "six" else name + "s": (value, "s") for name, value in self.ones.items()}
+ self.ones_ordinal = {
+ "zeroth": (0, "th"),
+ "first": (1, "st"),
+ "second": (2, "nd"),
+ "third": (3, "rd"),
+ "fifth": (5, "th"),
+ "twelfth": (12, "th"),
+ **{name + ("h" if name.endswith("t") else "th"): (value, "th") for name, value in self.ones.items() if value > 3 and value != 5 and value != 12},
+ }
+ self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
+
+ self.tens = {
+ "twenty": 20,
+ "thirty": 30,
+ "forty": 40,
+ "fifty": 50,
+ "sixty": 60,
+ "seventy": 70,
+ "eighty": 80,
+ "ninety": 90,
+ }
+ self.tens_plural = {name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()}
+ self.tens_ordinal = {name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()}
+ self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
+
+ self.multipliers = {
+ "hundred": 100,
+ "thousand": 1_000,
+ "million": 1_000_000,
+ "billion": 1_000_000_000,
+ "trillion": 1_000_000_000_000,
+ "quadrillion": 1_000_000_000_000_000,
+ "quintillion": 1_000_000_000_000_000_000,
+ "sextillion": 1_000_000_000_000_000_000_000,
+ "septillion": 1_000_000_000_000_000_000_000_000,
+ "octillion": 1_000_000_000_000_000_000_000_000_000,
+ "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
+ "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
+ }
+ self.multipliers_plural = {name + "s": (value, "s") for name, value in self.multipliers.items()}
+ self.multipliers_ordinal = {name + "th": (value, "th") for name, value in self.multipliers.items()}
+ self.multipliers_suffixed = {
+ **self.multipliers_plural,
+ **self.multipliers_ordinal,
+ }
+ self.decimals = {*self.ones, *self.tens, *self.zeros}
+
+ self.preceding_prefixers = {
+ "minus": "-",
+ "negative": "-",
+ "plus": "+",
+ "positive": "+",
+ }
+ self.following_prefixers = {
+ "pound": "ยฃ",
+ "pounds": "ยฃ",
+ "euro": "โฌ",
+ "euros": "โฌ",
+ "dollar": "$",
+ "dollars": "$",
+ "cent": "ยข",
+ "cents": "ยข",
+ }
+ self.prefixes = set(list(self.preceding_prefixers.values()) + list(self.following_prefixers.values()))
+ self.suffixers = {
+ "per": {"cent": "%"},
+ "percent": "%",
+ }
+ self.specials = {"and", "double", "triple", "point"}
+
+ self.words = set(
+ [
+ key
+ for mapping in [
+ self.zeros,
+ self.ones,
+ self.ones_suffixed,
+ self.tens,
+ self.tens_suffixed,
+ self.multipliers,
+ self.multipliers_suffixed,
+ self.preceding_prefixers,
+ self.following_prefixers,
+ self.suffixers,
+ self.specials,
+ ]
+ for key in mapping
+ ]
+ )
+ self.literal_words = {"one", "ones"}
+
+ def process_words(self, words: List[str]) -> Iterator[str]:
+ prefix: Optional[str] = None
+ value: Optional[Union[str, int]] = None
+ skip = False
+
+ def to_fraction(s: str):
+ try:
+ return Fraction(s)
+ except ValueError:
+ return None
+
+ def output(result: Union[str, int]):
+ nonlocal prefix, value
+ result = str(result)
+ if prefix is not None:
+ result = prefix + result
+ value = None
+ prefix = None
+ return result
+
+ if len(words) == 0:
+ return
+
+ for prev, current, next in windowed([None] + words + [None], 3):
+ if skip:
+ skip = False
+ continue
+
+ next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
+ has_prefix = current[0] in self.prefixes
+ current_without_prefix = current[1:] if has_prefix else current
+ if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
+ # arabic numbers (potentially with signs and fractions)
+ f = to_fraction(current_without_prefix)
+ assert f is not None
+ if value is not None:
+ if isinstance(value, str) and value.endswith("."):
+ # concatenate decimals / ip address components
+ value = str(value) + str(current)
+ continue
+ else:
+ yield output(value)
+
+ prefix = current[0] if has_prefix else prefix
+ if f.denominator == 1:
+ value = f.numerator # store integers as int
+ else:
+ value = current_without_prefix
+ elif current not in self.words:
+ # non-numeric words
+ if value is not None:
+ yield output(value)
+ yield output(current)
+ elif current in self.zeros:
+ value = str(value or "") + "0"
+ elif current in self.ones:
+ ones = self.ones[current]
+
+ if value is None:
+ value = ones
+ elif isinstance(value, str) or prev in self.ones:
+ if prev in self.tens and ones < 10: # replace the last zero with the digit
+ assert value[-1] == "0"
+ value = value[:-1] + str(ones)
+ else:
+ value = str(value) + str(ones)
+ elif ones < 10:
+ if value % 10 == 0:
+ value += ones
+ else:
+ value = str(value) + str(ones)
+ else: # eleven to nineteen
+ if value % 100 == 0:
+ value += ones
+ else:
+ value = str(value) + str(ones)
+ elif current in self.ones_suffixed:
+ # ordinal or cardinal; yield the number right away
+ ones, suffix = self.ones_suffixed[current]
+ if value is None:
+ yield output(str(ones) + suffix)
+ elif isinstance(value, str) or prev in self.ones:
+ if prev in self.tens and ones < 10:
+ assert value[-1] == "0"
+ yield output(value[:-1] + str(ones) + suffix)
+ else:
+ yield output(str(value) + str(ones) + suffix)
+ elif ones < 10:
+ if value % 10 == 0:
+ yield output(str(value + ones) + suffix)
+ else:
+ yield output(str(value) + str(ones) + suffix)
+ else: # eleven to nineteen
+ if value % 100 == 0:
+ yield output(str(value + ones) + suffix)
+ else:
+ yield output(str(value) + str(ones) + suffix)
+ value = None
+ elif current in self.tens:
+ tens = self.tens[current]
+ if value is None:
+ value = tens
+ elif isinstance(value, str):
+ value = str(value) + str(tens)
+ else:
+ if value % 100 == 0:
+ value += tens
+ else:
+ value = str(value) + str(tens)
+ elif current in self.tens_suffixed:
+ # ordinal or cardinal; yield the number right away
+ tens, suffix = self.tens_suffixed[current]
+ if value is None:
+ yield output(str(tens) + suffix)
+ elif isinstance(value, str):
+ yield output(str(value) + str(tens) + suffix)
+ else:
+ if value % 100 == 0:
+ yield output(str(value + tens) + suffix)
+ else:
+ yield output(str(value) + str(tens) + suffix)
+ elif current in self.multipliers:
+ multiplier = self.multipliers[current]
+ if value is None:
+ value = multiplier
+ elif isinstance(value, str) or value == 0:
+ f = to_fraction(value)
+ p = f * multiplier if f is not None else None
+ if f is not None and p.denominator == 1:
+ value = p.numerator
+ else:
+ yield output(value)
+ value = multiplier
+ else:
+ before = value // 1000 * 1000
+ residual = value % 1000
+ value = before + residual * multiplier
+ elif current in self.multipliers_suffixed:
+ multiplier, suffix = self.multipliers_suffixed[current]
+ if value is None:
+ yield output(str(multiplier) + suffix)
+ elif isinstance(value, str):
+ f = to_fraction(value)
+ p = f * multiplier if f is not None else None
+ if f is not None and p.denominator == 1:
+ yield output(str(p.numerator) + suffix)
+ else:
+ yield output(value)
+ yield output(str(multiplier) + suffix)
+ else: # int
+ before = value // 1000 * 1000
+ residual = value % 1000
+ value = before + residual * multiplier
+ yield output(str(value) + suffix)
+ value = None
+ elif current in self.preceding_prefixers:
+ # apply prefix (positive, minus, etc.) if it precedes a number
+ if value is not None:
+ yield output(value)
+
+ if next in self.words or next_is_numeric:
+ prefix = self.preceding_prefixers[current]
+ else:
+ yield output(current)
+ elif current in self.following_prefixers:
+ # apply prefix (dollars, cents, etc.) only after a number
+ if value is not None:
+ prefix = self.following_prefixers[current]
+ yield output(value)
+ else:
+ yield output(current)
+ elif current in self.suffixers:
+ # apply suffix symbols (percent -> '%')
+ if value is not None:
+ suffix = self.suffixers[current]
+ if isinstance(suffix, dict):
+ if next in suffix:
+ yield output(str(value) + suffix[next])
+ skip = True
+ else:
+ yield output(value)
+ yield output(current)
+ else:
+ yield output(str(value) + suffix)
+ else:
+ yield output(current)
+ elif current in self.specials:
+ if next not in self.words and not next_is_numeric:
+ # apply special handling only if the next word can be numeric
+ if value is not None:
+ yield output(value)
+ yield output(current)
+ elif current == "and":
+ # ignore "and" after hundreds, thousands, etc.
+ if prev not in self.multipliers:
+ if value is not None:
+ yield output(value)
+ yield output(current)
+ elif current == "double" or current == "triple":
+ if next in self.ones or next in self.zeros:
+ repeats = 2 if current == "double" else 3
+ ones = self.ones.get(next, 0)
+ value = str(value or "") + str(ones) * repeats
+ skip = True
+ else:
+ if value is not None:
+ yield output(value)
+ yield output(current)
+ elif current == "point":
+ if next in self.decimals or next_is_numeric:
+ value = str(value or "") + "."
+ else:
+ # should all have been covered at this point
+ raise ValueError(f"Unexpected token: {current}")
+ else:
+ # all should have been covered at this point
+ raise ValueError(f"Unexpected token: {current}")
+
+ if value is not None:
+ yield output(value)
+
+ def preprocess(self, s: str):
+ # replace " and a half" with " point five"
+ results = []
+
+ segments = re.split(r"\band\s+a\s+half\b", s)
+ for i, segment in enumerate(segments):
+ if len(segment.strip()) == 0:
+ continue
+ if i == len(segments) - 1:
+ results.append(segment)
+ else:
+ results.append(segment)
+ last_word = segment.rsplit(maxsplit=2)[-1]
+ if last_word in self.decimals or last_word in self.multipliers:
+ results.append("point five")
+ else:
+ results.append("and a half")
+
+ s = " ".join(results)
+
+ # put a space at number/letter boundary
+ s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
+ s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
+
+ # but remove spaces which could be a suffix
+ s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
+
+ return s
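+
+ # Illustration of the rule above (added comment, not in the upstream file): in
+ # "two and a half", the word before the phrase ("two") is a number word, so the text
+ # becomes "two point five"; in "a pound and a half", "pound" is not a decimal or
+ # multiplier word, so "and a half" is kept as-is.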
+
+ def postprocess(self, s: str):
+ def combine_cents(m: Match):
+ try:
+ currency = m.group(1)
+ integer = m.group(2)
+ cents = int(m.group(3))
+ return f"{currency}{integer}.{cents:02d}"
+ except ValueError:
+ return m.string
+
+ def extract_cents(m: Match):
+ try:
+ return f"ยข{int(m.group(1))}"
+ except ValueError:
+ return m.string
+
+ # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
+ s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
+ s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
+
+ # write "one(s)" instead of "1(s)", just for the readability
+ s = re.sub(r"\b1(s?)\b", r"one\1", s)
+
+ return s
+
+ def __call__(self, s: str):
+ s = self.preprocess(s)
+ s = " ".join(word for word in self.process_words(s.split()) if word is not None)
+ s = self.postprocess(s)
+
+ return s
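+
+ # Minimal usage sketch (editor's illustration; the variable name is hypothetical):
+ # num_norm = EnglishNumberNormalizer()
+ # num_norm("twenty one dollars and fifty cents")
+ # # process_words yields "$21 and ¢50", which postprocess merges into "$21.50"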
+
+
+class EnglishSpellingNormalizer:
+ """
+ Applies British-American spelling mappings as listed in [1].
+
+ [1] https://www.tysto.com/uk-us-spelling-list.html
+ """
+
+ def __init__(self):
+ mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
+ self.mapping = json.load(open(mapping_path))
+
+ def __call__(self, s: str):
+ return " ".join(self.mapping.get(word, word) for word in s.split())
+
+
+class EnglishTextNormalizer:
+ def __init__(self):
+ self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
+ self.replacers = {
+ # common contractions
+ r"\bwon't\b": "will not",
+ r"\bcan't\b": "can not",
+ r"\blet's\b": "let us",
+ r"\bain't\b": "aint",
+ r"\by'all\b": "you all",
+ r"\bwanna\b": "want to",
+ r"\bgotta\b": "got to",
+ r"\bgonna\b": "going to",
+ r"\bi'ma\b": "i am going to",
+ r"\bimma\b": "i am going to",
+ r"\bwoulda\b": "would have",
+ r"\bcoulda\b": "could have",
+ r"\bshoulda\b": "should have",
+ r"\bma'am\b": "madam",
+ # contractions in titles/prefixes
+ r"\bmr\b": "mister ",
+ r"\bmrs\b": "missus ",
+ r"\bst\b": "saint ",
+ r"\bdr\b": "doctor ",
+ r"\bprof\b": "professor ",
+ r"\bcapt\b": "captain ",
+ r"\bgov\b": "governor ",
+ r"\bald\b": "alderman ",
+ r"\bgen\b": "general ",
+ r"\bsen\b": "senator ",
+ r"\brep\b": "representative ",
+ r"\bpres\b": "president ",
+ r"\brev\b": "reverend ",
+ r"\bhon\b": "honorable ",
+ r"\basst\b": "assistant ",
+ r"\bassoc\b": "associate ",
+ r"\blt\b": "lieutenant ",
+ r"\bcol\b": "colonel ",
+ r"\bjr\b": "junior ",
+ r"\bsr\b": "senior ",
+ r"\besq\b": "esquire ",
+ # perfect tenses; ideally this would cover any past participle, but that's harder
+ r"'d been\b": " had been",
+ r"'s been\b": " has been",
+ r"'d gone\b": " had gone",
+ r"'s gone\b": " has gone",
+ r"'d done\b": " had done", # "'s done" is ambiguous
+ r"'s got\b": " has got",
+ # general contractions
+ r"n't\b": " not",
+ r"'re\b": " are",
+ r"'s\b": " is",
+ r"'d\b": " would",
+ r"'ll\b": " will",
+ r"'t\b": " not",
+ r"'ve\b": " have",
+ r"'m\b": " am",
+ }
+ self.standardize_numbers = EnglishNumberNormalizer()
+ self.standardize_spellings = EnglishSpellingNormalizer()
+
+ def __call__(self, s: str):
+ s = s.lower()
+
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
+ s = re.sub(self.ignore_patterns, "", s)
+ s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
+
+ for pattern, replacement in self.replacers.items():
+ s = re.sub(pattern, replacement, s)
+
+ s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
+ s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
+ s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
+
+ s = self.standardize_numbers(s)
+ s = self.standardize_spellings(s)
+
+ # now remove prefix/suffix symbols that are not preceded/followed by numbers
+ s = re.sub(r"[.$ยขโฌยฃ]([^0-9])", r" \1", s)
+ s = re.sub(r"([^0-9])%", r"\1 ", s)
+
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
+
+ return s
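+
+ # End-to-end usage sketch (editor's illustration; the exact output also depends on
+ # remove_symbols_and_diacritics from .basic):
+ # normalizer = EnglishTextNormalizer()
+ # normalizer("It costs $20,000!") # -> roughly "it costs $20000"
+ # normalizer("Mrs. Smith's flavoured tea") # -> roughly "missus smith is flavored tea"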
diff --git a/lmms_eval/tasks/librispeech/cn_tn.py b/lmms_eval/tasks/librispeech/cn_tn.py
new file mode 100644
index 00000000..18128d17
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/cn_tn.py
@@ -0,0 +1,1174 @@
+#!/usr/bin/env python3
+# coding=utf-8
+# copied from https://github.com/speechio/chinese_text_normalization/blob/master/python/cn_tn.py
+# Authors:
+# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
+# 2019.9 - 2022 Jiayu DU
+#
+# requirements:
+# - python 3.X
+# notes: python 2.X WILL fail or produce misleading results
+
+import argparse
+import csv
+import os
+import re
+import string
+import sys
+
+# ================================================================================ #
+# basic constant
+# ================================================================================ #
+CHINESE_DIGIS = "้ถไธไบไธๅไบๅ
ญไธๅ
ซไน"
+BIG_CHINESE_DIGIS_SIMPLIFIED = "้ถๅฃน่ดฐๅ่ไผ้ๆๆ็"
+BIG_CHINESE_DIGIS_TRADITIONAL = "้ถๅฃน่ฒณๅ่ไผ้ธๆๆ็"
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "ๅ็พๅไธ"
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "ๆพไฝฐไป่ฌ"
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "ไบฟๅ
ไบฌๅ็งญ็ฉฐๆฒๆถงๆญฃ่ฝฝ"
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "ๅๅ
ไบฌๅ็งญ็ฉฐๆบๆพๆญฃ่ผ"
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "ๅ็พๅไธ"
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "ๆพไฝฐไป่ฌ"
+
+ZERO_ALT = "ใ"
+ONE_ALT = "ๅนบ"
+TWO_ALTS = ["ไธค", "ๅ
ฉ"]
+
+POSITIVE = ["ๆญฃ", "ๆญฃ"]
+NEGATIVE = ["่ด", "่ฒ "]
+POINT = ["็น", "้ป"]
+# PLUS = [u'ๅ ', u'ๅ ']
+# SIL = [u'ๆ ', u'ๆง']
+
+FILLER_CHARS = ["ๅ", "ๅ"]
+
+ER_WHITELIST = "(ๅฟๅฅณ|ๅฟๅญ|ๅฟๅญ|ๅฅณๅฟ|ๅฟๅชณ|ๅฆปๅฟ|" "่ๅฟ|ๅฉดๅฟ|ๆฐ็ๅฟ|ๅฉดๅนผๅฟ|ๅนผๅฟ|ๅฐๅฟ|ๅฐๅฟ|ๅฟๆญ|ๅฟ็ซฅ|ๅฟ็ง|ๆๅฟๆ|ๅญคๅฟ|" "ๅฟๆ|ๅฟๅ|ๅฐๅฟๅบ|้นฟๅฟๅฒ|ๆญฃๅฟๅ
ซ็ป|ๅๅฟ้ๅฝ|็ๅฟ่ฒๅฅณ|ๆๅฟๅธฆๅฅณ|ๅ
ปๅฟ้ฒ่|็ดๅฟๅๅฅณ|" "ไฝณๅฟไฝณๅฆ|ๅฟๆๅ
ฝๆฐ|ๅฟๆ ๅธธ็ถ|ๅฟไธๅซๆฏไธ|ๅฟ่กๅ้ๆฏๆ
ๅฟง|ๅฟๅคงไธ็ฑ็ท|่ไนๅฟ)"
+ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST)
+
+# Chinese numeral system types
+NUMBERING_TYPES = ["low", "mid", "high"]
+
+CURRENCY_NAMES = "(ไบบๆฐๅธ|็พๅ
|ๆฅๅ
|่ฑ้|ๆฌงๅ
|้ฉฌๅ
|ๆณ้|ๅ ๆฟๅคงๅ
|ๆพณๅ
|ๆธฏๅธ|ๅ
ไปค|่ฌๅ
ฐ้ฉฌๅ
|็ฑๅฐๅ
ฐ้|" "้ๆ|่ทๅ
ฐ็พ|ๅๆฏๅบๅค|ๆฏๅกๅก|ๅฐๅฐผ็พ|ๆๅ็น|ๆฐ่ฅฟๅ
ฐๅ
|ๆฏ็ดข|ๅขๅธ|ๆฐๅ ๅกๅ
|้ฉๅ
|ๆณฐ้ข)"
+CURRENCY_UNITS = "((ไบฟ|ๅไธ|็พไธ|ไธ|ๅ|็พ)|(ไบฟ|ๅไธ|็พไธ|ไธ|ๅ|็พ|)ๅ
|(ไบฟ|ๅไธ|็พไธ|ไธ|ๅ|็พ|)ๅ|่ง|ๆฏ|ๅ)"
+COM_QUANTIFIERS = (
+ "(ๅน|ๅผ |ๅบง|ๅ|ๅบ|ๅฐพ|ๆก|ไธช|้ฆ|้|้ต|็ฝ|็ฎ|้กถ|ไธ|ๆฃต|ๅช|ๆฏ|่ขญ|่พ|ๆ|ๆ
|้ข|ๅฃณ|็ช |ๆฒ|ๅข|็พค|่
|"
+ "็ ฃ|ๅบง|ๅฎข|่ดฏ|ๆ|ๆ|ๅ|ไปค|ๆ|ๆ|็ฝ|ๅก|ๅฑฑ|ๅฒญ|ๆฑ|ๆบช|้|้|ๅ|ๅ|ๅฏน|ๅบ|ๅฃ|ๅคด|่|ๆฟ|่ทณ|ๆ|ไปถ|่ดด|"
+ "้|็บฟ|็ฎก|ๅ|ไฝ|่บซ|ๅ |่ฏพ|ๆฌ|้กต|ๅฎถ|ๆท|ๅฑ|ไธ|ๆฏซ|ๅ|ๅ|้ฑ|ไธค|ๆค|ๆ
|้ข|็ณ|้ง|้ฑ|ๅฟฝ|(ๅ|ๆฏซ|ๅพฎ)ๅ
|"
+ "ๆฏซ|ๅ|ๅ|ๅฏธ|ๅฐบ|ไธ|้|ๅฏป|ๅธธ|้บ|็จ|(ๅ|ๅ|ๅ|ๆฏซ|ๅพฎ)็ฑณ|ๆฎ|ๅบ|ๅ|ๅ|ๆ|็ณ|็|็ข|็ข|ๅ |ๆกถ|็ฌผ|็|"
+ "็|ๆฏ|้|ๆ|้
|็ฐ|็ฏฎ|็|ๆกถ|็ฝ|็ถ|ๅฃถ|ๅฎ|็|็ฎฉ|็ฎฑ|็
ฒ|ๅ|่ข|้ต|ๅนด|ๆ|ๆฅ|ๅญฃ|ๅป|ๆถ|ๅจ|ๅคฉ|็ง|ๅ|ๆฌ|"
+ "็บช|ๅฒ|ไธ|ๆด|ๅค|ๆฅ|ๅค|็ง|ๅฌ|ไปฃ|ไผ|่พ|ไธธ|ๆณก|็ฒ|้ข|ๅนข|ๅ |ๆก|ๆ น|ๆฏ|้|้ข|็|ๅผ |้ข|ๅ)"
+)
+
+
+# Punctuation information is based on the Zhon project (https://github.com/tsroten/zhon.git)
+CN_PUNCS_STOP = "๏ผ๏ผ๏ฝกใ"
+CN_PUNCS_NONSTOP = "๏ผ๏ผ๏ผ๏ผ
๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ๏ผ ๏ผป๏ผผ๏ผฝ๏ผพ๏ผฟ๏ฝ๏ฝ๏ฝ๏ฝ๏ฝ๏ฝ๏ฝ ๏ฝข๏ฝฃ๏ฝคใใใใใใใใใใใใใใใใใใใใใใใฐใพใฟโโโโโโโโโโฆโง๏นยทใใ-"
+CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP
+
+PUNCS = CN_PUNCS + string.punctuation
+PUNCS_TRANSFORM = str.maketrans(PUNCS, " " * len(PUNCS), "") # replace puncs with space
+
+
+# https://zh.wikipedia.org/wiki/ๅ
จ่กๅๅ่ก
+QJ2BJ = {
+ "ใ": " ",
+ "๏ผ": "!",
+ "๏ผ": '"',
+ "๏ผ": "#",
+ "๏ผ": "$",
+ "๏ผ
": "%",
+ "๏ผ": "&",
+ "๏ผ": "'",
+ "๏ผ": "(",
+ "๏ผ": ")",
+ "๏ผ": "*",
+ "๏ผ": "+",
+ "๏ผ": ",",
+ "๏ผ": "-",
+ "๏ผ": ".",
+ "๏ผ": "/",
+ "๏ผ": "0",
+ "๏ผ": "1",
+ "๏ผ": "2",
+ "๏ผ": "3",
+ "๏ผ": "4",
+ "๏ผ": "5",
+ "๏ผ": "6",
+ "๏ผ": "7",
+ "๏ผ": "8",
+ "๏ผ": "9",
+ "๏ผ": ":",
+ "๏ผ": ";",
+ "๏ผ": "<",
+ "๏ผ": "=",
+ "๏ผ": ">",
+ "๏ผ": "?",
+ "๏ผ ": "@",
+ "๏ผก": "A",
+ "๏ผข": "B",
+ "๏ผฃ": "C",
+ "๏ผค": "D",
+ "๏ผฅ": "E",
+ "๏ผฆ": "F",
+ "๏ผง": "G",
+ "๏ผจ": "H",
+ "๏ผฉ": "I",
+ "๏ผช": "J",
+ "๏ผซ": "K",
+ "๏ผฌ": "L",
+ "๏ผญ": "M",
+ "๏ผฎ": "N",
+ "๏ผฏ": "O",
+ "๏ผฐ": "P",
+ "๏ผฑ": "Q",
+ "๏ผฒ": "R",
+ "๏ผณ": "S",
+ "๏ผด": "T",
+ "๏ผต": "U",
+ "๏ผถ": "V",
+ "๏ผท": "W",
+ "๏ผธ": "X",
+ "๏ผน": "Y",
+ "๏ผบ": "Z",
+ "๏ผป": "[",
+ "๏ผผ": "\\",
+ "๏ผฝ": "]",
+ "๏ผพ": "^",
+ "๏ผฟ": "_",
+ "๏ฝ": "`",
+ "๏ฝ": "a",
+ "๏ฝ": "b",
+ "๏ฝ": "c",
+ "๏ฝ": "d",
+ "๏ฝ
": "e",
+ "๏ฝ": "f",
+ "๏ฝ": "g",
+ "๏ฝ": "h",
+ "๏ฝ": "i",
+ "๏ฝ": "j",
+ "๏ฝ": "k",
+ "๏ฝ": "l",
+ "๏ฝ": "m",
+ "๏ฝ": "n",
+ "๏ฝ": "o",
+ "๏ฝ": "p",
+ "๏ฝ": "q",
+ "๏ฝ": "r",
+ "๏ฝ": "s",
+ "๏ฝ": "t",
+ "๏ฝ": "u",
+ "๏ฝ": "v",
+ "๏ฝ": "w",
+ "๏ฝ": "x",
+ "๏ฝ": "y",
+ "๏ฝ": "z",
+ "๏ฝ": "{",
+ "๏ฝ": "|",
+ "๏ฝ": "}",
+ "๏ฝ": "~",
+}
+QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "")
+
+
+# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources:
+# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total
+CN_CHARS_COMMON = (
+ "ไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธไธขไธคไธฅไธงไธชไธซไธญไธฐไธฒไธดไธธไธนไธบไธปไธฝไธพ"
+ "ไนไนไน
ไนไนไนไนไนไนไนไนไนไนไนไนไนไนไนไนไนไนไน ไนกไนฆไนฉไนฐไนฑไนณไนธไนพไบไบไบไบไบไบไบไบไบไบ"
+ "ไบไบไบไบไบไบไบไบกไบขไบคไบฅไบฆไบงไบจไบฉไบซไบฌไบญไบฎไบฒไบณไบตไบถไบธไบนไบบไบฟไปไปไปไปไปไป
ไปไปไปไปไปไปไป"
+ "ไปไปไปไปไปไปไปไปไปไปไปไปกไปฃไปคไปฅไปจไปชไปซไปฌไปฐไปฒไปณไปตไปถไปทไปปไปฝไปฟไผไผไผไผไผไผไผไผไผไผไผไผ"
+ "ไผไผไผไผไผไผ ไผขไผฃไผคไผฅไผฆไผงไผชไผซไผญไผฏไผฐไผฒไผดไผถไผธไผบไผผไผฝไผพไฝไฝไฝไฝไฝไฝไฝไฝไฝไฝไฝไฝไฝไฝไฝ"
+ "ไฝไฝไฝไฝไฝไฝ ไฝฃไฝคไฝฅไฝฉไฝฌไฝฏไฝฐไฝณไฝดไฝถไฝธไฝบไฝปไฝผไฝฝไฝพไฝฟไพไพไพไพไพไพไพไพไพไพไพไพไพไพไพไพ ไพฃ"
+ "ไพฅไพฆไพงไพจไพฉไพชไพฌไพฎไพฏไพดไพตไพนไพฟไฟไฟไฟ
ไฟไฟไฟไฟไฟไฟไฟไฟไฟไฟไฟไฟไฟไฟไฟกไฟฃไฟฆไฟจไฟฉไฟชไฟซไฟญไฟฎไฟฏ"
+ "ไฟฑไฟณไฟตไฟถไฟธไฟบไฟพๅๅๅๅๅๅๅๅๅๅๅๅๅๅกๅฅๅฆๅงๅจๅฉๅชๅฌๅญๅฎๅดๅบๅปๅผๅพๅๅๅๅๅ"
+ "ๅๅๅๅๅๅๅกๅฅๅฌๅญๅฐๅฒๅถๅทๅปๅพๅฟๅๅๅ
ๅๅๅๅๅๅฃๅฅๅงๅจๅฉๅฌๅฒๅบๅปๅๅๅๅๅๅ"
+ "ๅฆๅงๅฌๅญๅฎๅฐๅณๅตๅปๅๅๅๅๅกๅฆๅณๅดๅฟๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ๅ
ขๅ
ฅๅ
จๅ
ซๅ
ฌๅ
ญ"
+ "ๅ
ฎๅ
ฐๅ
ฑๅ
ณๅ
ดๅ
ตๅ
ถๅ
ทๅ
ธๅ
นๅ
ปๅ
ผๅ
ฝๅๅๅ
ๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅขๅคๅฅๅฌๅฎๅฏๅฐๅฑๅฒๅณๅต"
+ "ๅถๅทๅปๅผๅฝๅๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅกๅคๅซๅญๅฏๅฐๅณๅถๅธๅนๅบๅปๅผๅฝๅฟๅๅๅๅๅๅ"
+ "ๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅคๅจๅฉๅซๅฌๅญๅฎๅฐๅณๅถๅทๅธๅนๅบๅปๅฝๅฟๅๅๅๅๅ
ๅๅๅๅๅ"
+ "ๅๅๅๅๅๅๅๅกๅฅๅงๅฉๅชๅฏๅฒๅฝๅฟๅๅๅๅๅๅๅๅๅๅๅ ๅกๅขๅฃๅจๅฉๅชๅซๅฌๅญๅฑๅฒๅณๅผ"
+ "ๅพๅฟๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅคๅฐๅบๅพๅฟๅๅ
ๅๅๅๅๅๅๅๅๅๅๅๅ ๅกๅฃๅฆๅชๅฎๅน"
+ "ๅบๅปๅผๅพๅฟๅๅๅ
ๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅกๅขๅฃๅคๅฆๅงๅซๅฌๅฎๅฏๅฐๅฑๅณๅดๅต"
+ "ๅทๅธๅบๅฟๅๅๅ
ๅๅๅๅๅๅๅๅๅๅๅๅขๅฃๅฅๅฆๅจๅฉๅฎๅปๅพๅฟๅๅๅๅๅๅๅๅๅๅๅๅ"
+ "ๅๅๅๅๅๅๅๅๅ ๅฃๅคๅฅๅฆๅจๅฉๅชๅซๅฌๅญๅฎๅฏๅฐๅฑๅฒๅณๅตๅถๅทๅธๅนๅปๅผๅฝๅๅๅๅๅๅๅ"
+ "ๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅกๅฃๅฆๅงๅจๅฉๅซๅฌๅญๅฎๅฏๅฑๅฒๅดๅตๅธๅนๅปๅผๅฝๅพๅๅๅ"
+ "ๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅขๅฃๅคๅฆๅจๅฑๅฒๅณๅตๅถๅทๅธๅปๅผๅฝๅๅๅๅๅๅๅๅๅๅ"
+ "ๅๅๅๅๅๅๅๅๅๅๅกๅฃๅคๅฅๅฆๅงๅจๅฉๅชๅซๅฌๅฏๅฑๅณๅดๅธๅบๅปๅฝๅฟๅๅๅๅๅๅๅๅๅๅ"
+ "ๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅขๅฅๅฆๅงๅจๅฉๅชๅญๅฎๅฑๅฒๅณๅบๅผๅฝๅฟๅๅๅๅๅๅๅๅๅ"
+ "ๅๅ ๅขๅฃๅคๅงๅชๅฌๅฎๅฏๅฐๅฑๅณๅตๅทๅผๅพๅฟๅๅๅๅๅๅๅๅๅๅๅกๅคๅฅๅฆๅงๅชๅซๅฌๅญๅฎๅฐๅด"
+ "ๅตๅถๅทๅธๅปๅผๅพๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅคๅงๅฑๅณๅตๅทๅนๅปๅฝๅพๅๅ
ๅๅ"
+ "ๅๅๅๅๅๅๅๅๅๅๅๅกๅฃๅคๅฅๅฆๅจๅชๅซๅฌๅฏๅฒๅณๅตๅทๅฝๅพๅๅๅๅๅๅๅๅๅๅๅๅๅก"
+ "ๅฃๅคๅงๅฌๅญๅฑๅฒๅดๅถๅนๅปๅฟๅๅๅๅๅๅๅๅๅๅๅๅขๅคๅจๅฉๅชๅซๅฌๅฑๅถๅปๅผๅๅ
ๅๅๅๅ"
+ "ๅๅฃๅญๅฏๅทๅผๅๅๅๅๅๅๅ ๅกๅขๅคๅซๅญๅฐๅฑๅดๅตๅทๅนๅบๅฝๅพๅฟๅๅๅๅๅๅๅๅๅๅๅๅข"
+ "ๅฃๅจๅฉๅชๅซๅฌๅญๅฎๅฏๅฐๅฒๅณๅนๅบๅปๅพๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅกๅคๅฅ"
+ "ๅฆๅจๅฉๅชๅซๅฌๅญๅฏๅฐๅณๅทๅปๅผๅฝๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅกๅขๅฃๅคๅฆๅงๅฉ"
+ "ๅซๅญๅฎๅฏๅฑๅฒๅดๅตๅธๅบๅพๅฟๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅๅ ๅคๅชๅซๅญๅฏๅดๅตๅธๅนๅบ"
+ "ๅผๅฝๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ ๅ กๅ คๅ งๅ จๅ ชๅ ฐๅ ฒๅ ตๅ ผๅ ฝๅ พๅกๅก
ๅกๅกๅกๅกๅกๅกๅกๅกๅกฅๅกซ"
+ "ๅกฌๅกฑๅกพๅขๅขๅขๅข
ๅขๅขๅขๅขๅขๅขๅขๅขๅขๅขๅขๅขกๅขฃๅขฆๅขจๅขฉๅขผๅฃๅฃ
ๅฃๅฃๅฃคๅฃซๅฃฌๅฃฎๅฃฐๅฃณๅฃถๅฃธๅฃนๅคๅคๅค"
+ "ๅคๅคๅคๅคๅคๅคๅคๅคๅคๅคคๅคฅๅคงๅคฉๅคชๅคซๅคฌๅคญๅคฎๅคฏๅคฑๅคดๅคทๅคธๅคนๅคบๅคผๅฅๅฅๅฅๅฅๅฅๅฅๅฅๅฅๅฅๅฅๅฅๅฅๅฅๅฅ"
+ "ๅฅๅฅๅฅๅฅ ๅฅกๅฅขๅฅฅๅฅญๅฅณๅฅดๅฅถๅฅธๅฅนๅฅฝๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆๅฆฃๅฆคๅฆฅๅฆงๅฆจๅฆฉๅฆชๅฆซๅฆญๅฆฎ"
+ "ๅฆฏๅฆฒๅฆนๅฆปๅฆพๅงๅงๅงๅงๅงๅงๅงๅงๅงๅงๅงๅงๅงๅงๅงๅงฃๅงคๅงฅๅงจๅงฌๅงฎๅงฑๅงถๅงนๅงปๅงฝๅงฟๅจๅจๅจๅจๅจ
ๅจๅจๅจ"
+ "ๅจๅจๅจๅจๅจๅจๅจๅจ ๅจฃๅจฅๅจฉๅจฑๅจฒๅจดๅจตๅจถๅจผๅฉๅฉๅฉๅฉๅฉๅฉๅฉๅฉๅฉๅฉๅฉ ๅฉขๅฉคๅฉงๅฉชๅฉซๅฉณๅฉดๅฉตๅฉถๅฉทๅฉบๅฉป"
+ "ๅฉผๅฉฟๅชๅชๅชๅชๅชๅชๅชๅชๅชๅชชๅชญๅชฑๅชฒๅชณๅชตๅชธๅชพๅซๅซๅซๅซๅซๅซๅซๅซๅซๅซๅซๅซๅซ ๅซกๅซฃๅซฆๅซฉๅซชๅซซๅซญๅซฑ"
+ "ๅซฝๅฌๅฌๅฌๅฌๅฌฅๅฌฌๅฌดๅฌทๅฌฟๅญๅญ
ๅญๅญๅญๅญๅญๅญๅญๅญๅญๅญๅญๅญๅญๅญๅญขๅญฃๅญคๅญฅๅญฆๅญฉๅญชๅญฌๅญฐๅญฑๅญณๅญตๅญบๅญฝ"
+ "ๅฎๅฎๅฎๅฎ
ๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎๅฎ ๅฎกๅฎขๅฎฃๅฎคๅฎฅๅฎฆๅฎงๅฎชๅฎซๅฎฌๅฎฐๅฎณๅฎดๅฎตๅฎถๅฎธๅฎนๅฎฝๅฎพ"
+ "ๅฎฟๅฏๅฏๅฏๅฏ
ๅฏๅฏๅฏๅฏๅฏๅฏๅฏๅฏๅฏๅฏกๅฏคๅฏฅๅฏจๅฏฎๅฏฐๅฏธๅฏนๅฏบๅฏปๅฏผๅฏฟๅฐๅฐๅฐๅฐๅฐๅฐๅฐๅฐๅฐๅฐๅฐๅฐๅฐๅฐ"
+ "ๅฐขๅฐคๅฐฅๅฐงๅฐจๅฐชๅฐฌๅฐฑๅฐดๅฐธๅฐนๅฐบๅฐปๅฐผๅฐฝๅฐพๅฐฟๅฑๅฑๅฑๅฑๅฑ
ๅฑๅฑๅฑๅฑๅฑๅฑๅฑๅฑๅฑๅฑๅฑๅฑ ๅฑกๅฑฃๅฑฅๅฑฆๅฑฏๅฑฑ"
+ "ๅฑนๅฑบๅฑผๅฑพๅฑฟๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒๅฒ ๅฒขๅฒฃๅฒจๅฒฉๅฒซๅฒฌๅฒญๅฒฑๅฒณๅฒตๅฒทๅฒธๅฒฝๅฒฟๅณๅณๅณ"
+ "ๅณๅณๅณๅณๅณๅณๅณๅณกๅณฃๅณคๅณฅๅณฆๅณงๅณจๅณชๅณญๅณฐๅณฑๅณปๅณฟๅดๅดๅดๅดๅดๅดๅดๅดๅดๅดๅดๅดๅดๅดๅดๅดๅดกๅดคๅดฆๅดง"
+ "ๅดฉๅดญๅดฎๅดดๅดถๅดฝๅดพๅดฟๅตๅต
ๅตๅตๅตๅตๅตๅตๅตๅตๅตๅตๅตฉๅตซๅตฌๅตฏๅตฒๅตดๅถๅถ
ๅถๅถๅถๅถๅถๅถๅถฆๅถฒๅถทๅท
ๅทๅท"
+ "ๅทๅทๅทๅทกๅทขๅทฅๅทฆๅทงๅทจๅทฉๅทซๅทฎๅทฏๅทฑๅทฒๅทณๅทดๅททๅทฝๅทพๅธๅธๅธๅธ
ๅธๅธๅธๅธๅธๅธๅธๅธๅธๅธๅธๅธๅธๅธๅธๅธก"
+ "ๅธฆๅธงๅธจๅธญๅธฎๅธฑๅธทๅธธๅธปๅธผๅธฝๅนๅนๅน
ๅนๅนๅนๅนๅนๅนๅนกๅนขๅนชๅนฒๅนณๅนดๅนถๅนธๅนบๅนปๅนผๅนฝๅนฟๅบๅบๅบๅบๅบๅบๅบ"
+ "ๅบๅบๅบๅบๅบๅบๅบๅบๅบๅบๅบๅบ ๅบคๅบฅๅบฆๅบงๅบญๅบฑๅบณๅบตๅบถๅบทๅบธๅบนๅบผๅบพๅปๅปๅปๅปๅปๅปๅปๅปๅปๅปๅปจๅปชๅปถๅปท"
+ "ๅปบๅปฟๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผๅผ ๅผขๅผฅๅผฆๅผงๅผจๅผฉๅผญๅผฏๅผฑๅผถๅผธๅผนๅผบๅผผๅฝๅฝๅฝๅฝๅฝ"
+ "ๅฝๅฝๅฝๅฝๅฝขๅฝคๅฝฆๅฝงๅฝฉๅฝชๅฝฌๅฝญๅฝฐๅฝฑๅฝณๅฝทๅฝนๅฝปๅฝผๅพๅพๅพๅพๅพ
ๅพๅพๅพๅพๅพๅพๅพๅพๅพๅพๅพๅพๅพๅพกๅพจๅพช"
+ "ๅพญๅพฎๅพตๅพทๅพผๅพฝๅฟๅฟ
ๅฟๅฟๅฟๅฟๅฟๅฟๅฟๅฟๅฟๅฟๅฟๅฟๅฟๅฟๅฟ ๅฟกๅฟคๅฟงๅฟชๅฟซๅฟญๅฟฎๅฟฑๅฟณๅฟตๅฟธๅฟบๅฟปๅฟฝๅฟพๅฟฟๆ"
+ "ๆๆๆๆๆ
ๆๆๆๆๆๆๆๆๆๆๆๆๆๆ ๆกๆฅๆฆๆงๆจๆฉๆชๆซๆฏๆตๆปๆผๆฟๆๆๆๆๆๆๆๆ"
+ "ๆๆๆๆๆๆขๆฃๆคๆงๆจๆฉๆชๆซๆฌๆญๆฏๆฐๆณๆถๆธๆนๆบๆปๆผๆฝๆฟๆๆๆๆๆๆๆๆๆๆๆๆๆๆ"
+ "ๆ ๆขๆฃๆฆๆจๆซๆฌๆญๆฏๆฐๆฑๆฒๆดๆธๆปๆผๆ
ๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆ ๆฆๆงๆจๆฉๆซๆฌๆญ"
+ "ๆฎๆฏๆฐๆณๆดๆถๆนๆบๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆ ๆฃๆคๆฆๆงๆซๆญๆฟๆ
ๆ
ๆ
ๆ
ๆ
ๆ
ๆ
ๆ
ๆ
ขๆ
ฅ"
+ "ๆ
งๆ
จๆ
ฌๆ
ญๆ
ฐๆ
ตๆ
ทๆๆๆๆๆๆงๆจๆฉๆฌๆญๆทๆบๆพๆๆๆๆๆๆๆๆฆๆตๆฟๆๆๆๆๆๆๆๆๆๆ"
+ "ๆๆๆๆๆๆๆๆๆกๆขๆฃๆคๆฅๆชๆฌๆญๆฎๆณๆดๆทๆฝๆพๆฟๆๆๆๆๆ
ๆๆๆๆๆๆๆๆๆๆๆๆ"
+ "ๆๆๆๆฃๆฆๆงๆฉๆชๆซๆฌๆญๆฎๆฏๆฐๆณๆถๆนๆบๆผๆฝๆพๆฟๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆ ๆก"
+ "ๆขๆคๆฅๆจๆซๆฌๆฑๆตๆนๆปๆผๆฝๆฟๆๆๆๆ
ๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆขๆฃๆคๆฅ"
+ "ๆฆๆงๆจๆฉๆฌๆญๆฎๆฏๆฑๆณๆดๆถๆทๆผๆฝๆพๆฟๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆ ๆกๆฃๆคๆฅๆฆๆจๆชๆซ"
+ "ๆฏๆฒๆนๆบๆฝๆๆๆ
ๆๆๆๆๆๆๆๆๆๆๆๆกๆขๆฃๆงๆฉๆญๆฎๆฏๆถๆทๆบๆปๆฝๆๆๆๆๆๆๆๆ"
+ "ๆๆๆๆๆๆๆ ๆขๆฃๆฅๆงๆจๆฉๆชๆฌๆญๆฎๆฐๆณๆดๆทๆธๆบๆผๆพๆๆๆๆๆๆๆๆๆๆ ๆกๆฃๆฉๆชๆญ"
+ "ๆณๆดๆถๆธๆฝๆฟๆๆๆๆ
ๆๆๆๆๆๆๆๆๆๆๆ ๆกๆฆๆชๆฌๆญๆดๆบๆฝๆๆๆ
ๆๆๆๆๆๆๆๆ"
+ "ๆๆๆงๆฉๆญๆดๆธๆนๆฝๆๆๆ
ๆๆๆๆๆๆๆๆคๆฉๆฌๆญๆฎๆฐๆตๆทๆธๆบๆผๆๆๆ
ๆๆๆๆๆๆๆข"
+ "ๆคๆฆๆฟๆๆๆๆๆฅๆซๆฎๆฏๆถๆธๆนๆปๆฝๆพๆฟๆ
ๆๆๆๆๆๆๆๆๆๆๆๆๆขๆฃๆฆๆฉๆซๆฌๆฐๆฒๆด"
+ "ๆทๆๆๆๆๆๆๆๆๆๆๆๆๆ ๆกๆคๆฅๆงๆฉๆซๆญๆฏๆฐๆถๆนๆผๆฝๆๆๆๆ
ๆๆๆๆๆๆๆๆๆ"
+ "ๆๆ ๆขๆฅๆฆๆงๆจๆฉๆฌๆญๆฎๆฏๆฐๆฑๆดๆตๆถๆทๆธๆบๆปๆฟๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆ"
+ "ๆๆ ๆกๆฃๆคๆฅๆงๆจๆชๆซๆญๆฏๆฑๆณๆดๆตๆถๆบๆผๆฝๆพๆๆๆ
ๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆก"
+ "ๆขๆคๆฆๆจๆชๆซๆฎๆฏๆฐๆฑๆดๆถๆทๆบๆพๆๆๆ
ๆๆๆๆๆๆๆๆงๆจๆฎๆฒๆดๆตๆถๆนๆพๆฟๆๆๆๆๆ"
+ "ๆๆฆๆฉๆฐๆฒๆณๆดๆทๆนๆผๆพๆฟๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆฆๆจๆชๆซๆฌๆญๆฏๆฑๆณๆดๆตๆธๆบๆฝ"
+ "ๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆๆ ๆกๆฅๆงๆจๆฉๆชๆญๆฏๆฐๆฒๆณๆตๆทๆปๆผๆพๆฟๆๆๆ
"
+ "ๆๆๆๆๆๆๆๆๆๆๆๆๆขๆฃๆฅๆงๆจๆชๆซๆญๆฏๆฐๆฒๆณๆตๆถๆทๆธๆนๆๆๆๆๆๆๆๆๆๆๆ"
+ "ๆๆๆๆๆๆๆๆ ๆขๆฅๆฉๆฌๆฏๆฐๆฑๆณๆดๆทๆฝๆฟๆ ๆ
ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ ๆ กๆ ฉ"
+ "ๆ ชๆ ฒๆ ณๆ ดๆ ทๆ ธๆ นๆ ปๆ ผๆ ฝๆ พๆกๆกๆกๆกๆกๆก
ๆกๆกๆกๆกๆกๆกๆกๆกๆกๆกๆกๆก ๆกกๆกขๆกฃๆกคๆกฅๆกฆๆกงๆกจๆกฉๆกซๆกฏ"
+ "ๆกฒๆกดๆกถๆกทๆกนๆขๆขๆข
ๆขๆขๆขๆขๆขๆข ๆขขๆขฃๆขฆๆขงๆขจๆขญๆขฏๆขฐๆขณๆขดๆขตๆขผๆขฝๆขพๆขฟๆฃๆฃๆฃๆฃๆฃๆฃๆฃๆฃๆฃๆฃๆฃ"
+ "ๆฃๆฃ ๆฃฃๆฃคๆฃจๆฃชๆฃซๆฃฌๆฃฎๆฃฐๆฃฑๆฃตๆฃนๆฃบๆฃปๆฃผๆฃฝๆคๆคๆค
ๆคๆคๆคๆคๆคๆคๆคๆคๆคๆค ๆคคๆคชๆคญๆคฐๆคดๆคธๆคนๆคฝๆคฟๆฅ"
+ "ๆฅๆฅๆฅๆฅๆฅๆฅๆฅๆฅ ๆฅฃๆฅฆๆฅฉๆฅชๆฅซๆฅฎๆฅฏๆฅทๆฅธๆฅนๆฅผๆฆๆฆๆฆๆฆ
ๆฆๆฆๆฆๆฆๆฆๆฆๆฆๆฆๆฆๆฆๆฆๆฆงๆฆจๆฆซๆฆญๆฆฐๆฆฑ"
+ "ๆฆดๆฆทๆฆปๆงๆงๆงๆงๆงๆงๆงๆงๆงๆงๆงๆง ๆงญๆงฑๆงฒๆงฝๆงฟๆจๆจๆจๆจๆจกๆจจๆจชๆจฏๆจฑๆจตๆจฝๆจพๆฉๆฉๆฉๆฉๆฉๆฉๆฉๆฉ"
+ "ๆฉกๆฉฅๆฉฆๆฉฑๆฉนๆฉผๆชๆชๆชๆชๆชๆชๆชๆช ๆชฉๆชซๆชฌๆซๆฌๆฌ ๆฌกๆฌขๆฌฃๆฌคๆฌงๆฌฒๆฌธๆฌนๆฌบๆฌปๆฌพๆญๆญ
ๆญๆญๆญๆญๆญๆญขๆญฃ"
+ "ๆญคๆญฅๆญฆๆญงๆญชๆญนๆญปๆญผๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎๆฎกๆฎฃๆฎชๆฎณๆฎดๆฎตๆฎทๆฎฟๆฏๆฏๆฏ
ๆฏๆฏๆฏๆฏๆฏๆฏ"
+ "ๆฏๆฏๆฏๆฏๆฏๆฏๆฏๆฏกๆฏชๆฏซๆฏฏๆฏณๆฏตๆฏนๆฏฝๆฐ
ๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐๆฐกๆฐขๆฐคๆฐฆๆฐงๆฐจๆฐฉๆฐชๆฐฎ"
+ "ๆฐฏๆฐฐๆฐฒๆฐดๆฐธๆฐพๆฐฟๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑๆฑ ๆฑกๆฑคๆฑงๆฑจๆฑฉๆฑชๆฑซๆฑญๆฑฐๆฑฒๆฑดๆฑถๆฑนๆฑฝ"
+ "ๆฑพๆฒๆฒๆฒๆฒๆฒ
ๆฒๆฒๆฒๆฒๆฒๆฒๆฒๆฒๆฒๆฒๆฒๆฒๆฒๆฒๆฒกๆฒฃๆฒคๆฒฅๆฒฆๆฒงๆฒจๆฒฉๆฒชๆฒซๆฒญๆฒฎๆฒฑๆฒณๆฒธๆฒนๆฒบๆฒปๆฒผๆฒฝ"
+ "ๆฒพๆฒฟๆณๆณๆณๆณ
ๆณๆณๆณๆณๆณๆณๆณๆณๆณๆณๆณๆณๆณๆณๆณๆณ ๆณกๆณขๆณฃๆณฅๆณจๆณชๆณซๆณฎๆณฏๆณฐๆณฑๆณณๆณตๆณทๆณธๆณบๆณปๆณผ"
+ "ๆณฝๆณพๆดๆดๆดๆดๆดๆดๆดๆดๆดๆดๆดๆดๆดๆดๆดๆดๆดขๆดฃๆดฅๆดงๆดจๆดชๆดซๆดญๆดฎๆดฑๆดฒๆดณๆดดๆดตๆดธๆดนๆดบๆดปๆดผๆดฝๆดพๆดฟ"
+ "ๆตๆตๆต
ๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆตๆต ๆตกๆตฃๆตฅๆตฆๆตฉๆตชๆตฌๆตญๆตฎๆตฏๆตฐๆตฒๆตดๆตทๆตธ"
+ "ๆตผๆถๆถๆถ
ๆถๆถๆถๆถๆถๆถๆถๆถๆถๆถๆถๆถๆถๆถๆถๆถ ๆถกๆถขๆถฃๆถคๆถฆๆถงๆถจๆถฉๆถชๆถซๆถฎๆถฏๆถฒๆถดๆถตๆถธๆถฟๆทๆทๆท
"
+ "ๆทๆทๆทๆทๆทๆทๆทๆทๆทๆทๆทๆทๆทๆท ๆทกๆทคๆทฆๆทซๆทฌๆทฎๆทฏๆทฑๆทณๆทดๆททๆทนๆทปๆทผๆธ
ๆธๆธๆธๆธๆธๆธๆธๆธๆธๆธๆธ"
+ "ๆธ ๆธกๆธฃๆธคๆธฅๆธฉๆธซๆธญๆธฏๆธฐๆธฒๆธดๆธธๆธบๆธผๆนๆนๆนๆนๆนๆนๆนๆนๆนๆนๆนๆนๆนๆนๆนฃๆนซๆนฎๆนฒๆนดๆนพๆนฟๆบๆบๆบ
ๆบ"
+ "ๆบๆบๆบๆบๆบๆบๆบๆบๆบๆบๆบ ๆบขๆบฅๆบฆๆบงๆบชๆบฏๆบฑๆบฒๆบดๆบตๆบถๆบทๆบนๆบบๆบปๆบฝๆปๆปๆปๆปๆปๆปๆปๆปๆปๆปๆปๆปๆป"
+ "ๆปๆปๆปๆปๆปๆป ๆปกๆปขๆปคๆปฅๆปฆๆปงๆปจๆปฉๆปชๆปซๆปดๆปนๆผๆผๆผๆผๆผๆผๆผๆผๆผๆผๆผ ๆผคๆผฆๆผฉๆผชๆผซๆผญๆผฏๆผฑๆผณๆผดๆผถ"
+ "ๆผทๆผนๆผปๆผผๆผพๆฝๆฝๆฝๆฝๆฝๆฝๆฝๆฝๆฝๆฝๆฝขๆฝฆๆฝฉๆฝญๆฝฎๆฝฒๆฝดๆฝตๆฝธๆฝบๆฝผๆฝฝๆฝพๆพๆพๆพๆพๆพๆพๆพๆพๆพๆพกๆพฅๆพง"
+ "ๆพชๆพญๆพณๆพดๆพถๆพนๆพผๆพฝๆฟๆฟๆฟๆฟๆฟๆฟๆฟๆฟ ๆฟกๆฟฉๆฟฎๆฟฏ็็็็็็็ฃ็ฑ็ต็น็ผ็็็็็ซ็ญ็ฏ็ฐ็ต"
+ "็ถ็ธ็ผ็พ็ฟ็็
็็็็็็็็็็็็็็็ฃ็ซ็ฌ็ญ็ฎ็ฏ็ฑ็ณ็ท็ธ็น็ป็ผ็ฝ็็็็็"
+ "็็็็็็็็็ ็ค็ฆ็ง็จ็ฉ็ซ็ฌ็ญ็ฏ็ถ็ท็น็บ็ป็ฝ็็็็็็็็็็็็็็็ฆ็ฏ"
+ "็ฐ็ฑ็ถ็
็
็
็
็
็
็
็
็
็
็
็
ค็
ฆ็
ง็
จ็
ฎ็
ฒ็
ณ็
ด็
ธ็
บ็
ฝ็็็็็็็็็็็ ็ฅ็จ็ฌ็ต"
+ "็น็ป็็็็็็็็็ ็ฅ็ง็ฎ็น็็็็็็็จ็ช็ฌ็ฐ็ฑ็ต็ถ็ท็ธ็น็ป็ฝ็ฟ็็็็็็"
+ "็็็็็็็็ก็ข็ค็ฅ็ฆ็ง็ฉ็ฎ็ฏ็ฒ็ต็น็บ็ป็พ็ฟ็็็็็็็็็็็จ็ฌ็ฏ็ฐ็ด็ถ็ท"
+ "็ธ็น็็็็็็็็็็็็็็็ ็ก็จ็ฉ็ฌ็ญ็ฎ็ฏ็ฐ็ฑ็ฒ็ณ็ด็ท็ธ็บ็ป็ผ็็็็็็"
+ "็็็็็็็็ก็ข็ฅ็ฉ็ช็ซ็ฌ็ฎ็ฏ็ฐ็ฑ็ด็ท็น็บ็พ็ฟ็็็็็ ็ฌ็ญ็ฏ็ด็พ็็็็็็"
+ "็็็็็็็็็็็็ ็ก็ข็ค็ฅ็ฆ็ฉ็ซ็ญ็ฎ็ฏ็ฐ็ฑ็ฒ็ณ็ถ็ท็น็บ็ป็ผ็ฟ็็็
็็็็"
+ "็็็็็็็็็็็็็็ ็ข็ฃ็ฅ็ฆ็ง็ฉ็ช็ซ็ญ็ฐ็ฒ็ต็ท็ธ็น็บ็ฝ็็็็
็็็็็"
+ "็็็็็็็็ก็ข็ค็ฅ็ฆ็จ็ช็ซ็ฌ็ญ็ฎ็ฏ็ฐ็ฒ็ณ็ด็ต็ถ็ผ็็็็็็
็็็็็็็็"
+ "็็็็็็็ข็ง็จ็ฌ็ญ็ฐ็ฑ็ณ็ถ็ท็พ็็็็็็็็็็็็็็็ ็ฅ็ง็จ็ฉ็ช็ฌ็ฎ็ฑ"
+ "็ฒ็บ็็็็็็็ ็ข็ฃ็ค็ฆ็ฎ็ฏ็ด็ถ็ท็ป็ฟ็็็็็็็็็็็ก็ฅ็ฆ็จ็ฉ็ช็ซ็ฌ็ญ็ฏ"
+ "็ฐ็ฑ็ฒ็ณ็ต็ท็ธ็บ็ป็พ็็
็็็็็็็็็็็็ค็ฅ็ฆ็ช็ฌ็ฏ็ฒ็ด็ธ็น็ฟ็็็็็็"
+ "็็็็็็็็็ ็ก็ข็ฃ็ค็ฅ็ซ็ฌ็ญ็ฎ็ฏ็ฐ็ฑ็ฒ็ณ็ด็ต็ธ็น็ผ็ฝ็พ็็็็
็็็็็็"
+ "็็็็็็็ข็ฃ็ค็ฆ็ง็จ็ช็ซ็ฐ็ฑ็ด็น็ผ็ฟ็็็็
็็็็็็็็็็็ ็ข็ค็ฅ็ฆ็ฉ"
+ "็ช็ซ็ญ็ฐ็ณ็ด็ต็ธ็ผ็พ็ฟ็็็็็็็็็็ฃ็ซ็ฏ็ธ็ป็ฝ็พ็ฟ็็็็็็็็็็็็"
+ "็็็ค็ฆ็ญ็ฎ็ฑ็ฒ็ด็ฟ็็
็็็็็็็็็็็็็็็็็ฅ็ฆ็ฎ็ฏ็ฑ็ฒ็ด็ท็ธ็น็ผ็พ"
+ "็็็็็็็็็็็็ ็ข็ฆ็จ็ฉ็ฌ็ญ็ฏ็ต็ถ็ท็ธ็บ็ผ็็็็็็็็็็็ก็ข็ฃ็ฅ็ฆ"
+ "็จ็ซ็ฌ็น็ฝ็พ็ฟ็็็
็็็็็็็็ ็ข็ฅ็ง็ฉ็ช็ซ็ฌ็ญ็ฐ็ณ็ต็ป็ฝ็ฟ็็็็็็ข็ฃ็ฅ"
+ "็ง็ฉ็ซ็ฌ็ญ็ฎ็ฐ็ณ็ถ็ธ็ป็ผ็พ็ฟ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ็ ฃ็ ฅ็ ง็ ซ็ ฌ็ ญ็ ฎ"
+ "็ ฐ็ ด็ ต็ ท็ ธ็ น็ บ็ ป็ ผ็ พ็ก็ก็ก
็ก็ก็ก็ก็ก็ก็ก็ก็ก็ก็ก็ก็ก็ก็กช็กซ็กฌ็กญ็กฎ็กผ็กฟ็ข็ข็ข็ข็ข็ข"
+ "็ข็ข็ข็ข็ข็ข็ข็ข็ข็ข็ขก็ขฃ็ขฅ็ขง็ขจ็ขฐ็ขฑ็ขฒ็ขณ็ขด็ขถ็ขน็ขพ็ฃ็ฃ
็ฃ็ฃ็ฃ็ฃ็ฃ็ฃ็ฃ็ฃ็ฃ็ฃก็ฃจ็ฃฌ็ฃฒ็ฃด็ฃท"
+ "็ฃน็ฃป็ค็ค
็ค็ค็ค็คด็คต็คบ็คผ็คพ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ็ฅ ็ฅข็ฅฅ็ฅง็ฅจ็ฅญ"
+ "็ฅฏ็ฅฒ็ฅท็ฅธ็ฅบ็ฅผ็ฅพ็ฆ็ฆ็ฆ็ฆ
็ฆ็ฆ็ฆ็ฆ็ฆ็ฆ็ฆ็ฆ็ฆค็ฆง็ฆณ็ฆน็ฆบ็ฆป็ฆฝ็ฆพ็ง็ง็ง็ง็ง็ง็ง็ง็ง็ง็ง็ง็งฃ"
+ "็งค็งฆ็งง็งฉ็งซ็งฌ็งญ็งฏ็งฐ็งธ็งป็งฝ็งพ็จ็จ็จ็จ็จ็จ็จ็จ็จ็จ็จ็จ็จ็จ็จ ็จฃ็จณ็จท็จน็จป็จผ็จฝ็จฟ็ฉ็ฉ็ฉ็ฉ"
+ "็ฉ็ฉ็ฉ็ฉฐ็ฉด็ฉถ็ฉท็ฉธ็ฉน็ฉบ็ฉฟ็ช็ช็ช็ช็ช
็ช็ช็ช็ช็ช็ช็ช็ช็ช็ช็ช็ช็ช็ช ็ชฃ็ชฅ็ชฆ็ชจ็ชฌ็ชญ็ชณ็ชธ็ชฟ็ซ"
+ "็ซ็ซ็ซ็ซ็ซ็ซ็ซ ็ซฃ็ซฅ็ซฆ็ซซ็ซญ็ซฏ็ซน็ซบ็ซฝ็ซฟ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ็ฌ ็ฌค็ฌฅ็ฌฆ็ฌจ็ฌช็ฌซ็ฌฌ็ฌฎ็ฌฏ"
+ "็ฌฑ็ฌณ็ฌธ็ฌบ็ฌผ็ฌพ็ญ็ญ
็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ็ญ ็ญข็ญค็ญฅ็ญฆ็ญฎ็ญฑ็ญฒ็ญต็ญถ็ญท็ญน็ญป็ญผ็ญพ็ฎ็ฎ
"
+ "็ฎ็ฎ็ฎ็ฎ็ฎ็ฎ็ฎ็ฎ็ฎก็ฎข็ฎฆ็ฎง็ฎจ็ฎฉ็ฎช็ฎซ็ฎฌ็ฎญ็ฎฑ็ฎด็ฎธ็ฏ็ฏ็ฏ็ฏ็ฏ็ฏ็ฏ็ฏ็ฏ็ฏก็ฏฅ็ฏฆ็ฏช็ฏฎ็ฏฏ็ฏฑ็ฏท็ฏผ็ฏพ"
+ "็ฐ็ฐ็ฐ็ฐ็ฐ็ฐ็ฐ็ฐ็ฐ็ฐ็ฐ ็ฐง็ฐช็ฐฐ็ฐธ็ฐฟ็ฑ็ฑ็ฑ็ฑฅ็ฑณ็ฑด็ฑป็ฑผ็ฑฝ็ฒ็ฒ็ฒ็ฒ็ฒ็ฒ็ฒ็ฒ็ฒ็ฒ็ฒข็ฒค็ฒฅ็ฒช็ฒฎ"
+ "็ฒฑ็ฒฒ็ฒณ็ฒน็ฒผ็ฒฝ็ฒพ็ฒฟ็ณ็ณ
็ณ็ณ็ณ็ณ็ณ็ณ็ณ็ณ็ณ็ณ็ณ็ณ็ณ ็ณจ็ณฏ็ณต็ณป็ด็ด ็ดข็ดง็ดซ็ดฏ็ต็ตฎ็ตท็ถฆ็ถฎ็ธ ็ธข"
+ "็ธป็น็น็น็บ็บ็บ ็บก็บข็บฃ็บค็บฅ็บฆ็บง็บจ็บฉ็บช็บซ็บฌ็บญ็บฎ็บฏ็บฐ็บฑ็บฒ็บณ็บด็บต็บถ็บท็บธ็บน็บบ็บป็บผ็บฝ็บพ็บฟ็ป็ป"
+ "็ป็ป็ป็ป
็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป็ป ็ปก็ปข็ปฃ็ปค็ปฅ็ปฆ็ปง็ปจ็ปฉ"
+ "็ปช็ปซ็ปญ็ปฎ็ปฏ็ปฐ็ปฑ็ปฒ็ปณ็ปด็ปต็ปถ็ปท็ปธ็ปน็ปบ็ปป็ปผ็ปฝ็ปพ็ปฟ็ผ็ผ็ผ็ผ็ผ็ผ
็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ"
+ "็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ็ผ ็ผก็ผข็ผฃ็ผค็ผฅ็ผฆ็ผง็ผจ็ผฉ็ผช็ผซ็ผฌ็ผญ็ผฎ็ผฏ็ผฐ็ผฑ็ผฒ็ผณ็ผด็ผต็ผถ็ผธ็ผบ็ฝ็ฝ็ฝ
็ฝ็ฝ"
+ "็ฝ็ฝ็ฝ็ฝ็ฝ็ฝ็ฝ็ฝก็ฝข็ฝจ็ฝฉ็ฝช็ฝฎ็ฝฑ็ฝฒ็ฝด็ฝถ็ฝน็ฝฝ็ฝพ็พ็พ็พ็พ็พ็พ็พ็พ็พ็พ็พ็พ็พ็พก็พค็พง็พฏ็พฐ็พฑ็พฒ"
+ "็พธ็พน็พผ็พฝ็พฟ็ฟ็ฟ็ฟ็ฟ็ฟ
็ฟ็ฟ็ฟ็ฟ็ฟ็ฟ็ฟ็ฟ็ฟ็ฟ็ฟ็ฟ ็ฟก็ฟฅ็ฟฆ็ฟฉ็ฟฎ็ฟฏ็ฟฐ็ฟฑ็ฟณ็ฟท็ฟป็ฟผ็ฟพ่่่่่
"
+ "่่่่่่่่่่่่่่่่่ ่ข่ค่ฅ่ฆ่ง่จ่ฉ่ช่ฐ่ฑ่ณ่ต่ถ่ท่ธ่ป่ฝ่ฟ่่่่่"
+ "่่่่่่่ฉ่ช่ฑ่ฟ่่่่่่่่่่่่่่่ ่ก่ข่ค่ฅ่ฉ่ช่ซ่ญ่ฎ่ฏ่ฑ่ฒ่ด่ท่ธ"
+ "่บ่ผ่ฝ่พ่ฟ่่่่่่่่่่่่่่่่่่่ ่ก่ฃ่ค่ฅ่ง่จ่ฉ่ช่ซ่ฌ่ญ่ฏ่ฐ่ฑ่ฒ่ณ"
+ "่ด่ถ่ธ่บ่ผ่ฝ่่่่่่่่่่่่่่่่่่ฉ่ฌ่ฏ่ฑ่ฒ่ถ่ธ่พ่ฟ่
่
่
่
่
่
่
่
"
+ "่
่
่
่
่
่
่
่
ฅ่
ง่
จ่
ฉ่
ญ่
ฎ่
ฏ่
ฐ่
ฑ่
ด่
น่
บ่
ป่
ผ่
ฝ่
พ่
ฟ่่่่่่่่่่่่ฆ่จ่ณ่บ่ป"
+ "่่่่่่่่่ฃ่ง่ช่ฌ่ญ่ณ่ด่ป่ผ่พ่่่่่
่่่่่่่่่่่ ่ข่ฃ่ฅ่ช่ซ่ฌ"
+ "่ญ่ฏ่ฐ่ฑ่ฒ่ณ่ด่ต่ถ่ท่ธ่น่ป่พ่่
่่่่่่่่่จ่ฎ่ฏ่ฐ่ฒ่ณ่ด่บ่ฝ่พ่ฟ่่่่่"
+ "่่่่่่่่่่่่่ ่ก่ฃ่ค่ฅ่ฆ่จ่ฉ่ช่ซ่ฌ่ญ่ฎ่ฏ่ฐ่ฑ่ณ่ด่ท่ธ่น่ผ่ฝ่พ่่่่"
+ "่่่่่่่่่่่่่่่่่่่ ่ก่ฃ่ค่ฅ่ฆ่ง่ซ่ฏ่ฑ่ด่ท่น่ป่พ่่่่่่
่"
+ "่่่่่่่่่่่่่่่่ง่จ่ซ่ฌ่ญ่ฏ่ฑ่ณ่ด่ต่ถ่ธ่น่บ่ผ่ฝ่่่่่่่่่"
+ "่่่่่่่่่่่่ ่ก่ฃ่ค่ฅ่ฆ่ง่จ่ฉ่ช่ซ่ฌ่ญ่ฎ่ฏ่ท่ธ่ป่ผ่ฝ่
่่่่่่่่"
+ "่่่่ ่จ่ฉ่ช่ซ่ฐ่ฑ่ฒ่ณ่ด่ถ่ท่ธ่น่บ่ผ่ฝ่ฟ่่่่
่่่่่่่่่่่่่ ่ก่ฅ"
+ "่ฉ่ช่ฐ่ฑ่ฒ่น่ผ่ฝ่่่่่่่่่่่่่่่ฃ่ค่ฅ่ฆ่ง่จ่ฉ่ฑ่ณ่ธ่น่ผ่ฝ่่่่่"
+ "่่่่่ก่ฃ่ฉ่ซ่ฌ่ญ่ฐ่ฑ่ณ่ด่ต่ถ่ธ่บ่่่่่่่่่่่่่่ก่จ่ฏ่ฑ่ฒ่ด่ธ่น่บ"
+ "่ป่ฝ่ฟ่่่่่่่่่่่่่่่ ่ข่ฃ่ฅ่ฆ่ฌ่ฐ่ผ่ฟ่่่่่่่่่่่ก่ซ่ฌ่ท"
+ "่ธ่น่บ่ป่ผ่ฝ่่่่่่่่่ค่จ่ฐ่ฒ่ด่น่บ่ป่พ่่่
่่่่่ข่ค่จ่ช่ฎ่ฏ่ฐ่ณ่ท่ธ"
+ "่น่ฟ่่่่่่่่่ ่ค่ฆ่จ่ฉ่ป่ฟ่
่่่่ง่ฉ่ธ่ผ่่่่่่่่่่ข่ค่ซ่ฌ่ฎ่ฑ"
+ "่ท่ธ่น่บ่ป่ผ่ฝ่พ่ฟ่่่่่่่่่่่่่่ฃ่ค่ง่จ่ฉ่ช่ฌ่ฏ่ฐ่ฑ่ฒ่ด่ถ่บ่่่่"
+ "่่่่่่่่่่่่่่่ค่ฉ่ญ่ฎ่ฐ่ฑ่ฒ่ณ่ด่ธ่น่พ่่่่่่่่่่่่่่"
+ "่่่่่ก่ข่ฃ่ฅ่ฉ่ฎ่ฑ่ด่ท่ป่พ่ฟ่่่่่่่่่่ ่ฃ่ค่ฅ่ฎ่ฐ่ฒ่ด่ถ่ป่ผ่ฝ่พ่่"
+ "่
่่่่่่ ่ฃ่จ่ซ่ฌ่ญ่ฏ่ฑ่ณ่ต่บ่ฝ่่่่่่่่่ ่ฅ่ช่ซ่ฎ่น่พ่ ่ ่ ่ ่ ่ ่ ก"
+ "่ ข่ ฒ่ น่ ผ่ก่ก่ก่ก
่ก่ก่ก่ก่ก่ก่ก่ก ่กก่กข่กฃ่กฅ่กจ่กฉ่กซ่กฌ่กฎ่กฐ่กฒ่กท่กฝ่กพ่กฟ่ข่ข่ข่ข
่ข่ข่ข่ข่ข"
+ "่ข่ข่ข่ขข่ขค่ขช่ขซ่ขญ่ขฏ่ขฑ่ขท่ขผ่ฃ่ฃ่ฃ
่ฃ่ฃ่ฃ่ฃ่ฃ่ฃ่ฃ่ฃ่ฃ่ฃ่ฃ่ฃข่ฃฃ่ฃค่ฃฅ่ฃจ่ฃฐ่ฃฑ่ฃณ่ฃด่ฃธ่ฃน่ฃผ่ฃพ่ค"
+ "่ค่ค่ค่ค่ค่ค่ค่ค่ค่คก่คฅ่คช่คซ่คฏ่คฐ่คด่คถ่ฅ่ฅ่ฅ่ฅ่ฅ่ฅ่ฅ่ฅฆ่ฅซ่ฅป่ฅฟ่ฆ่ฆ่ฆ่ง่ง่ง่ง่ง
่ง่ง่ง่ง"
+ "่ง่ง่ง่ง่ง่ง่ง่ง่ง่ง่ง่ง่ง่งฃ่งฅ่งฆ่งซ่งญ่งฏ่งฑ่งณ่งฟ่จ่จ่จ่จ่จพ่ฉ่ฉ่ฉน่ช่ช่ช่ฌ่ญฆ่ญฌ่ฎก่ฎข่ฎฃ่ฎค"
+ "่ฎฅ่ฎฆ่ฎง่ฎจ่ฎฉ่ฎช่ฎซ่ฎญ่ฎฎ่ฎฏ่ฎฐ่ฎฑ่ฎฒ่ฎณ่ฎด่ฎต่ฎถ่ฎท่ฎธ่ฎน่ฎบ่ฎป่ฎผ่ฎฝ่ฎพ่ฎฟ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ
่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ"
+ "่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ่ฏ ่ฏก่ฏข่ฏฃ่ฏค่ฏฅ่ฏฆ่ฏง่ฏจ่ฏฉ่ฏซ่ฏฌ่ฏญ่ฏฎ่ฏฏ่ฏฐ่ฏฑ่ฏฒ่ฏณ่ฏด่ฏต่ฏท"
+ "่ฏธ่ฏน่ฏบ่ฏป่ฏผ่ฏฝ่ฏพ่ฏฟ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ
่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ่ฐ ่ฐก"
+ "่ฐข่ฐฃ่ฐค่ฐฅ่ฐฆ่ฐง่ฐจ่ฐฉ่ฐช่ฐซ่ฐฌ่ฐญ่ฐฎ่ฐฏ่ฐฐ่ฐฑ่ฐฒ่ฐณ่ฐด่ฐต่ฐถ่ฐท่ฐผ่ฐฟ่ฑ่ฑ่ฑ่ฑ่ฑ่ฑ่ฑ่ฑก่ฑข่ฑจ่ฑช่ฑซ่ฑฎ่ฑณ่ฑธ่ฑน"
+ "่ฑบ่ฒ่ฒ
่ฒ่ฒ่ฒ่ฒ่ฒ่ฒ่ด่ด่ด่ดก่ดข่ดฃ่ดค่ดฅ่ดฆ่ดง่ดจ่ดฉ่ดช่ดซ่ดฌ่ดญ่ดฎ่ดฏ่ดฐ่ดฑ่ดฒ่ดณ่ดด่ดต่ดถ่ดท่ดธ่ดน่ดบ่ดป่ดผ"
+ "่ดฝ่ดพ่ดฟ่ต่ต่ต่ต่ต่ต
่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต่ต ่ตก่ตข่ตฃ่ตค"
+ "่ตฆ่ตง่ตช่ตซ่ตญ่ตฐ่ตณ่ตด่ตต่ตถ่ตท่ถ่ถ่ถ
่ถ่ถ่ถ่ถ่ถ่ถฃ่ถฏ่ถฑ่ถณ่ถด่ถต่ถธ่ถบ่ถผ่ถพ่ถฟ่ท่ท่ท่ท่ท่ท่ท่ท่ท่ท"
+ "่ท่ท่ท่ท่ท่ท่ท่ทฃ่ทค่ทจ่ทช่ทฌ่ทฏ่ทฑ่ทณ่ทต่ทถ่ทท่ทธ่ทน่ทบ่ทป่ทฝ่ธ
่ธ่ธ่ธ่ธ่ธ่ธ่ธ่ธ่ธ่ธข่ธฃ่ธฆ่ธฉ่ธช่ธฌ่ธฎ"
+ "่ธฏ่ธฑ่ธต่ธถ่ธน่ธบ่ธฝ่น่น่น่น่น
่น่น่น่น่น่น่น่น่น่น่น่นข่นฆ่นฉ่นฌ่นญ่นฏ่นฐ่นฒ่นด่นถ่นผ่นฝ่นพ่นฟ่บ่บ
่บ"
+ "่บ่บ่บ่บ่บ่บซ่บฌ่บฏ่บฒ่บบ่ฝฆ่ฝง่ฝจ่ฝฉ่ฝช่ฝซ่ฝฌ่ฝญ่ฝฎ่ฝฏ่ฝฐ่ฝฑ่ฝฒ่ฝณ่ฝด่ฝต่ฝถ่ฝท่ฝธ่ฝน่ฝบ่ฝป่ฝผ่ฝฝ่ฝพ่ฝฟ่พ่พ่พ่พ"
+ "่พ่พ
่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พ่พฃ่พจ่พฉ่พซ่พฐ่พฑ่พน่พฝ่พพ่พฟ่ฟ่ฟ่ฟ"
+ "่ฟ
่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟ่ฟข่ฟค่ฟฅ่ฟฆ่ฟจ่ฟฉ่ฟช่ฟซ่ฟญ่ฟฎ่ฟฐ่ฟณ่ฟท่ฟธ่ฟน่ฟบ่ฟฝ้้้้้้
้"
+ "้้้้้้้้้้้้้้้้้ ้ก้ข้ฆ้ญ้ฎ้ฏ้ด้ต้ถ้ธ้ป้ผ้พ้้้้้้้้้้"
+ "้้้้้ข้ฃ้ฅ้จ้ญ้ฎ้ด้ต้น้ฝ้ฟ้้้้้้้้้้้้้้ ้ก้ข้ฃ้ฆ้จ้ช้ฌ้ฎ้ฏ้ฐ้ฑ"
+ "้ฒ้ณ้ด้ต้ถ้ธ้น้บ้ป้ฝ้พ้ฟ้้้้
้้้้้้้้้้้้้้ก้ข้ค้ฆ้ง้จ้ช้ซ้ญ้ฏ้ด"
+ "้ธ้ฝ้พ้ฟ้้้้้
้้้้้้้้้ ้ข้ฃ้ซ้ฏ้ฑ้น้
้
้
้
้
้
้
้
้
้
้
้
้
้
้
้
"
+ "้
้
ก้
ข้
ฃ้
ค้
ฅ้
ฆ้
ฉ้
ช้
ฌ้
ฎ้
ฏ้
ฐ้
ฑ้
ฒ้
ด้
ต้
ถ้
ท้
ธ้
น้
บ้
ฝ้
พ้
ฟ้
้้้้้้้้้้้ข้จ้ช้ญ"
+ "้ฎ้ฏ้ด้ต้บ้พ้้้้้้้้้้้ด้้ฎ้้้พ้ช้้้้พ้ซ้้้้้้้้้้้้"
+ "้้้้้้้้้้้้้้ ้ก้ข้ฃ้ค้ฅ้ฆ้ง้จ้ฉ้ช้ซ้ฌ้ญ้ฎ้ฏ้ฐ้ฑ้ฒ้ณ้ด้ต้ท้น้บ้ป้ผ"
+ "้ฝ้พ้ฟ้้้้้้
้้้้้้้้้้้้้้้้้้้้้้้้ ้ก้ข้ฃ้ค้ฅ้ง้จ"
+ "้ฉ้ช้ซ้ฌ้ญ้ฎ้ฏ้ฐ้ฑ้ฒ้ณ้ด้ต้ถ้ท้ธ้น้บ้ป้ผ้ฝ้พ้ฟ้้้้้้
้้้้้้้้้้้"
+ "้้้้้้้้้้้้้้้้ก้ข้ฃ้ค้ฅ้ฆ้ง้จ้ฉ้ช้ซ้ฌ้ญ้ฎ้ฏ้ฐ้ฑ้ฒ้ณ้ด้ต้ถ้ท้ธ้น"
+ "้บ้ป้ผ้ฝ้พ้ฟ้้้้้้
้้้้้้้้้้้้้้้้้้้้้้้้้ ้ก้ข้ฃ"
+ "้ค้ฅ้ฆ้ง้จ้ฉ้ช้ซ้ฌ้ญ้ฎ้ฏ้ฐ้ฑ้ฒ้ณ้ด้ต้ถ้ฟ้จ้ฉ้ช้ซ้ญ้ฎ้ฏ้ฐ้ฑ้ฒ้ณ้ด้ต้ถ้ท้ธ้น้บ้ป้ผ"
+ "้ฝ้พ้ฟ้้้้้้
้้้้้้้้้้้้้้้้้้้้้้้ก้ช้ฎ้ฑ้ฒ้ณ้ด้ต้ถ"
+ "้ป้ผ้ฝ้ฟ้้้้
้้้้้้้้้้้้้้้้ก้ข้ค้ง้จ้ฉ้ช้ฌ้ฒ้ด้ต้ถ้ท้้
้้"
+ "้้้้้้้้้้ง้ฉ้ฐ้ณ้ถ้น้บ้ผ้ฝ้พ้้้้
้้้้้้้้้้้ ้จ้ฉ้ช้ฏ้ฑ้ณ"
+ "้ถ้ท้น้พ้้้้
้้้้้้้้้้้้จ้ช้ญ้ฐ้ฒ้ธ้น้พ้้้้้้้ ้ก้ข้ฅ้ฉ้ฌ้ฐ"
+ "้ณ้ด้ถ้ธ้บ้ผ้ฝ้ฟ้้
้้้้้้้ ้ก้ฃ้ง้จ้ซ้ฌ้ญ้ฎ้ฏ้ฒ้ณ้ด้้ฆ้ง้จ้ฉ้ช้ซ้ฌ้ญ้ณ้ต"
+ "้ถ้กต้กถ้กท้กธ้กน้กบ้กป้กผ้กฝ้กพ้กฟ้ข้ข้ข้ข้ข้ข
้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข้ข"
+ "้ข้ข้ข ้ขก้ขข้ขค้ขฅ้ขฆ้ขง้ฃ้ฃ้ฃ้ฃ้ฃ้ฃ้ฃ้ฃ้ฃ้ฃ้ฃ้ฃ้ฃ้ฃง้ฃจ้ค้ค้คฎ้ฅ้ฅ้ฅฅ้ฅง้ฅจ้ฅฉ้ฅช้ฅซ้ฅฌ้ฅญ้ฅฎ้ฅฏ้ฅฐ"
+ "้ฅฑ้ฅฒ้ฅณ้ฅด้ฅต้ฅถ้ฅท้ฅธ้ฅน้ฅบ้ฅป้ฅผ้ฅฝ้ฅฟ้ฆ้ฆ้ฆ้ฆ
้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆ้ฆฅ"
+ "้ฆง้ฆจ้ฉฌ้ฉญ้ฉฎ้ฉฏ้ฉฐ้ฉฑ้ฉฒ้ฉณ้ฉด้ฉต้ฉถ้ฉท้ฉธ้ฉน้ฉบ้ฉป้ฉผ้ฉฝ้ฉพ้ฉฟ้ช้ช้ช้ช้ช้ช
้ช้ช้ช้ช้ช้ช้ช้ช้ช้ช้ช้ช"
+ "้ช้ช้ช้ช้ช้ช้ช้ช้ช้ช้ช้ช้ช้ช ้ชก้ชข้ชฃ้ชค้ชฅ้ชฆ้ชง้ชจ้ชฐ้ชฑ้ชถ้ชท้ชธ้ชบ้ชผ้ซ้ซ้ซ้ซ้ซ
้ซ้ซ้ซ้ซ้ซ้ซ"
+ "้ซก้ซข้ซฆ้ซซ้ซญ้ซฏ้ซน้ซป้ซฝ้ฌ้ฌ้ฌ้ฌ้ฌ้ฌ้ฌ้ฌฃ้ฌฏ้ฌฒ้ฌถ้ฌท้ฌป้ฌผ้ญ้ญ้ญ้ญ้ญ
้ญ้ญ้ญ้ญ้ญ้ญ้ญ้ญ้ญ้ฑผ้ฑฝ้ฑพ"
+ "้ฑฟ้ฒ้ฒ้ฒ้ฒ้ฒ
้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ้ฒ ้ฒก้ฒข้ฒฃ้ฒค้ฒฅ้ฒฆ้ฒง้ฒจ"
+ "้ฒฉ้ฒช้ฒซ้ฒฌ้ฒญ้ฒฎ้ฒฏ้ฒฐ้ฒฑ้ฒฒ้ฒณ้ฒด้ฒต้ฒท้ฒธ้ฒน้ฒบ้ฒป้ฒผ้ฒฝ้ฒพ้ฒฟ้ณ้ณ้ณ้ณ้ณ้ณ
้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ"
+ "้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ้ณ ้ณก้ณข้ณฃ้ณค้ธ้ธ ้ธก้ธข้ธฃ้ธค้ธฅ้ธฆ้ธง้ธจ้ธฉ้ธช้ธซ้ธฌ้ธญ้ธฎ้ธฏ้ธฐ้ธฑ้ธฒ้ธณ้ธต้ธถ"
+ "้ธท้ธธ้ธน้ธบ้ธป้ธผ้ธฝ้ธพ้ธฟ้น้น้น้น้น้น
้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น้น"
+ "้น ้นก้นข้นฃ้นค้นฆ้นง้นจ้นฉ้นช้นซ้นฌ้นญ้นฎ้นฏ้นฐ้นฑ้นฒ้นณ้นด้นพ้นฟ้บ้บ้บ้บ้บ้บ้บ้บ้บ้บ้บ้บฆ้บธ้บน้บป้บฝ้บพ้ป"
+ "้ป้ป้ป้ป้ป้ป้ป้ป้ป้ป้ป้ป้ป ้ปก้ปข้ปฅ้ปง้ปฉ้ปช้ปฏ้ปน้ปป้ปผ้ปพ้ผ้ผ้ผ้ผ้ผ้ผ้ผ้ผ้ผ ้ผข้ผฉ้ผซ้ผฌ้ผฏ้ผฑ้ผท"
+ "้ผน้ผป้ผฝ้ผพ้ฝ้ฝ้ฝ้ฝ้ฝ้ฝฟ้พ้พ้พ้พ้พ้พ
้พ้พ้พ้พ้พ้พ้พ้พ้พ้พ้พ้พ ้พข้ฟ้ฟ้ฟใใใฎใใใใฆใ"
+ "ใในใใ ใ ใคใฅใงใงใงใซฐใฌใฌใฌใญใญใฎพใฐใณใณใณใดใตใถฒใธใธใบใปฌใฝใฟ ไไฎไ
ไไ
ไนไไไไก"
+ "ไฒไไไไจไซไฌไไไชไดไฃไไขบไขผไฃไฅฝไฆไฒไฒ ไฒขไดไดไดไดไดไดไดไถฎ๐
ค๐ ถ๐ ณ๐ก๐ก๐ฃ๐ฃฒ๐ฃฒ๐ฃธฃ๐คง๐คฉฝ"
+ "๐คซ๐ฅฒ๐ฅข๐ฅจ๐ฅป๐ฆก๐ฆ๐ฆถ๐ฆผ๐ฆญ๐ฆฐก๐งฟน๐จ๐จธ๐จ๐จ ๐จญ๐จฑ๐จฑ๐จฑ๐จฑ๐จบ๐ฉฝพ๐ฉพ๐ฉพ๐ช๐ชฃป๐ชค๐ชจฐ๐ชจถ๐ชฉ๐ชพข๐ซง๐ซจ๐ซท๐ซธ๐ซญ๐ซ๐ซฃ๐ซฏ"
+ "๐ซฒ๐ซฝ๐ซ๐ซ๐ซ๐ซก๐ซง๐ซฏ๐ซถ๐ซน๐ซ๐ซ๐ซถ๐ซฎ๐ซฏ๐ซณ๐ซง๐ซด๐ซ๐ซ๐ซฆ๐ซง๐ซจ๐ซช๐ซฌ๐ซ๐ซ๐ซญ๐ซญ๐ซฉ๐ซ
๐ซฆ๐ซน๐ซผ๐ซ ๐ซ ๐ซ ๐ซขธ๐ซซ๐ซญ"
+ "๐ซญข๐ซญผ๐ซฎ๐ซฐ๐ซตท๐ซถ๐ซทท๐ซธฉ๐ฌฉ๐ฌช๐ฌฉ๐ฌ๐ฌ๐ฌ๐ฌน๐ฌผ๐ฌ๐ฌค๐ฌ๐ฌ๐ฌก๐ฌค๐ฌ๐ฌ๐ฌ๐ฌ๐ฌ๐ฌ๐ฌก๐ฌฉ๐ฌซ๐ฌฌ๐ฌญ๐ฌฏ๐ฌ๐ฌ๐ฌ๐ฌฌ๐ฌฏ๐ฌ"
+ "๐ฌ๐ฌฝ๐ฌฃ๐ฌฃ๐ฌฃก๐ฌฃณ๐ฌค๐ฌค๐ฌค๐ฌจ๐ฌจ๐ฌฉฝ๐ฌชฉ๐ฌฌฉ๐ฌฌญ๐ฌฌฎ๐ฌฌฑ๐ฌฌธ๐ฌฌน๐ฌฌป๐ฌฌฟ๐ฌญ๐ฌญ๐ฌญ๐ฌญ๐ฌญ๐ฌญค๐ฌญฉ๐ฌญฌ๐ฌญฏ๐ฌญณ๐ฌญถ๐ฌญธ๐ฌญผ๐ฌฎฑ๐ฌฎฟ๐ฌฏ๐ฌฏ๐ฌฑ๐ฌฑ"
+ "๐ฌณต๐ฌณถ๐ฌณฝ๐ฌณฟ๐ฌด๐ฌด๐ฌด๐ฌถ๐ฌถ๐ฌถ๐ฌถ๐ฌถ๐ฌถ ๐ฌถจ๐ฌถญ๐ฌถฎ๐ฌท๐ฌธ๐ฌธ๐ฌธฃ๐ฌธฆ๐ฌธช๐ฌนผ๐ฌบ๐ฌบ"
+)
+CN_CHARS_EXT = "吶诶屌囧飚屄"
+
+CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT
+IN_CH_CHARS = {c: True for c in CN_CHARS}
+
+EN_CHARS = string.ascii_letters + string.digits
+IN_EN_CHARS = {c: True for c in EN_CHARS}
+
+VALID_CHARS = CN_CHARS + EN_CHARS + " "
+IN_VALID_CHARS = {c: True for c in VALID_CHARS}
+
+
+# ================================================================================ #
+# basic class
+# ================================================================================ #
+class ChineseChar(object):
+ """
+    A Chinese character.
+    Each character has a simplified and a traditional form,
+    e.g. simplified = '负', traditional = '負'.
+    It can be rendered as either form when converting.
+ """
+
+ def __init__(self, simplified, traditional):
+ self.simplified = simplified
+ self.traditional = traditional
+ # self.__repr__ = self.__str__
+
+ def __str__(self):
+ return self.simplified or self.traditional or None
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+ """
+    A Chinese number/unit character.
+    Besides the simplified and traditional forms, each character also has "big" (formal) variants,
+    e.g. '陆' and '陸'.
+ """
+
+ def __init__(self, power, simplified, traditional, big_s, big_t):
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
+ self.power = power
+ self.big_s = big_s
+ self.big_t = big_t
+
+ def __str__(self):
+ return "10^{}".format(self.power)
+
+ @classmethod
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+ if small_unit:
+ return ChineseNumberUnit(power=index + 1, simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
+ elif numbering_type == NUMBERING_TYPES[0]:
+ return ChineseNumberUnit(power=index + 8, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+ elif numbering_type == NUMBERING_TYPES[1]:
+ return ChineseNumberUnit(power=(index + 2) * 4, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+ elif numbering_type == NUMBERING_TYPES[2]:
+ return ChineseNumberUnit(power=pow(2, index + 3), simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+ else:
+ raise ValueError("Counting type should be in {0} ({1} provided).".format(NUMBERING_TYPES, numbering_type))
+
+
+class ChineseNumberDigit(ChineseChar):
+ """
+    A Chinese digit character.
+ """
+
+ def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
+ self.value = value
+ self.big_s = big_s
+ self.big_t = big_t
+ self.alt_s = alt_s
+ self.alt_t = alt_t
+
+ def __str__(self):
+ return str(self.value)
+
+ @classmethod
+ def create(cls, i, v):
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+ """
+    A Chinese math symbol (sign or decimal point).
+ """
+
+ def __init__(self, simplified, traditional, symbol, expression=None):
+ super(ChineseMath, self).__init__(simplified, traditional)
+ self.symbol = symbol
+ self.expression = expression
+ self.big_s = simplified
+ self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+ """
+    The Chinese number system (units, digits and math symbols).
+ """
+
+ pass
+
+
+class MathSymbol(object):
+ """
+    Math symbols used by the Chinese number system (simplified/traditional), e.g.
+    positive = ['正', '正']
+    negative = ['负', '負']
+    point = ['点', '點']
+ """
+
+ def __init__(self, positive, negative, point):
+ self.positive = positive
+ self.negative = negative
+ self.point = point
+
+ def __iter__(self):
+ for v in self.__dict__.values():
+ yield v
+
+
+# class OtherSymbol(object):
+# """
+#     Other symbols
+# """
+#
+# def __init__(self, sil):
+# self.sil = sil
+#
+# def __iter__(self):
+# for v in self.__dict__.values():
+# yield v
+
+
+# ================================================================================ #
+# basic utils
+# ================================================================================ #
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+    """
+    Create the Chinese number system for the given numbering type (default: 'mid').
+    NUMBERING_TYPES = ['low', 'mid', 'high']:
+        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
+        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
+        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
+    Returns the corresponding NumberSystem instance.
+    """
+
+    # Chinese number units of '亿' and larger
+ all_larger_units = zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+ larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
+    # Chinese number units of '十, 百, 千, 万'
+ all_smaller_units = zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+ smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
+    # digits
+ chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+ # symbols
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
+ point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+ system = NumberSystem()
+ system.units = smaller_units + larger_units
+ system.digits = digits
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+ # system.symbols = OtherSymbol(sil_cn)
+ return system
+
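+# A rough sketch of what create_system() builds with the default 'mid' type (illustrative;
+# the exact unit/digit constants are defined earlier in this file):
+#   units:  十=10^1, 百=10^2, 千=10^3, 万=10^4, 亿=10^8, 兆=10^12, 京=10^16, ...
+#   digits: 零..九, with alternative forms for 0/1/2 (e.g. 〇/幺/两)
+#   math:   正 (+), 负 (-), 点 (.)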
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+ def get_symbol(char, system):
+ for u in system.units:
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+ return u
+ for d in system.digits:
+ if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
+ return d
+ for m in system.math:
+ if char in [m.traditional, m.simplified]:
+ return m
+
+ def string2symbols(chinese_string, system):
+ int_string, dec_string = chinese_string, ""
+ for p in [system.math.point.simplified, system.math.point.traditional]:
+ if p in chinese_string:
+ int_string, dec_string = chinese_string.split(p)
+ break
+ return [get_symbol(c, system) for c in int_string], [get_symbol(c, system) for c in dec_string]
+
+ def correct_symbols(integer_symbols, system):
+ """
+        一百八 to 一百八十
+        一亿一千三百万 to 一亿 一千万 三百万
+ """
+
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
+ if integer_symbols[0].power == 1:
+ integer_symbols = [system.digits[1]] + integer_symbols
+
+ if len(integer_symbols) > 1:
+ if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
+ integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
+
+ result = []
+ unit_count = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ result.append(s)
+ unit_count = 0
+ elif isinstance(s, CNU):
+ current_unit = CNU(s.power, None, None, None, None)
+ unit_count += 1
+
+ if unit_count == 1:
+ result.append(current_unit)
+ elif unit_count > 1:
+ for i in range(len(result)):
+ if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
+ result[-i - 1] = CNU(result[-i - 1].power + current_unit.power, None, None, None, None)
+ return result
+
+ def compute_value(integer_symbols):
+ """
+        Compute the integer value of the symbol sequence.
+        When the current unit is larger than the previous one, all previously accumulated
+        values are multiplied by the current unit as well,
+        e.g. '两千万' = 2000 * 10000, not 2000 + 10000.
+ """
+ value = [0]
+ last_power = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ value[-1] = s.value
+ elif isinstance(s, CNU):
+ value[-1] *= pow(10, s.power)
+ if s.power > last_power:
+ value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
+ last_power = s.power
+ value.append(0)
+ return sum(value)
+
+ system = create_system(numbering_type)
+ int_part, dec_part = string2symbols(chinese_string, system)
+ int_part = correct_symbols(int_part, system)
+ int_str = str(compute_value(int_part))
+ dec_str = "".join([str(d.value) for d in dec_part])
+ if dec_part:
+ return "{0}.{1}".format(int_str, dec_str)
+ else:
+ return int_str
+
+
+def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False, traditional=False, alt_zero=False, alt_one=False, alt_two=True, use_zeros=True, use_units=True):
+ def get_value(value_string, use_zeros=True):
+ striped_string = value_string.lstrip("0")
+
+ # record nothing if all zeros
+ if not striped_string:
+ return []
+
+ # record one digits
+ elif len(striped_string) == 1:
+ if use_zeros and len(value_string) != len(striped_string):
+ return [system.digits[0], system.digits[int(striped_string)]]
+ else:
+ return [system.digits[int(striped_string)]]
+
+ # recursively record multiple digits
+ else:
+ result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
+ result_string = value_string[: -result_unit.power]
+ return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power :])
+
+ system = create_system(numbering_type)
+
+ int_dec = number_string.split(".")
+ if len(int_dec) == 1:
+ int_string = int_dec[0]
+ dec_string = ""
+ elif len(int_dec) == 2:
+ int_string = int_dec[0]
+ dec_string = int_dec[1]
+ else:
+ raise ValueError("invalid input num string with more than one dot: {}".format(number_string))
+
+ if use_units and len(int_string) > 1:
+ result_symbols = get_value(int_string)
+ else:
+ result_symbols = [system.digits[int(c)] for c in int_string]
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
+ if dec_string:
+ result_symbols += [system.math.point] + dec_symbols
+
+ if alt_two:
+ liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, system.digits[2].big_s, system.digits[2].big_t)
+ for i, v in enumerate(result_symbols):
+ if isinstance(v, CND) and v.value == 2:
+ next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
+ if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
+ if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
+ result_symbols[i] = liang
+
+    # if big is True, '两' will not be used and `alt_two` has no impact on output
+ if big:
+ attr_name = "big_"
+ if traditional:
+ attr_name += "t"
+ else:
+ attr_name += "s"
+ else:
+ if traditional:
+ attr_name = "traditional"
+ else:
+ attr_name = "simplified"
+
+ result = "".join([getattr(s, attr_name) for s in result_symbols])
+
+ # if not use_zeros:
+ # result = result.strip(getattr(system.digits[0], attr_name))
+
+ if alt_zero:
+ result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
+
+ if alt_one:
+ result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
+
+ for i, p in enumerate(POINT):
+ if result.startswith(p):
+ return CHINESE_DIGIS[0] + result
+
+ # ^10, 11, .., 19
+ if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
+ result = result[1:]
+
+ return result
+
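+# Quick illustrative checks for the two converters above (a sketch; outputs assume the
+# default 'mid' numbering system and the digit/unit constants defined earlier,
+# including 两 as the alternative form of 2):
+#   chn2num("一百二十三")   -> "123"
+#   chn2num("两千万零五十") -> "20000050"
+#   num2chn("123")          -> "一百二十三"
+#   num2chn("10086", alt_two=False, use_units=False) -> "一零零八六"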
+
+# ================================================================================ #
+# different types of rewriters
+# ================================================================================ #
+class Cardinal:
+ """
+    CARDINAL class: plain cardinal numbers.
+ """
+
+ def __init__(self, cardinal=None, chntext=None):
+ self.cardinal = cardinal
+ self.chntext = chntext
+
+ def chntext2cardinal(self):
+ return chn2num(self.chntext)
+
+ def cardinal2chntext(self):
+ return num2chn(self.cardinal)
+
+
+class Digit:
+ """
+    DIGIT class: digit-by-digit numbers (e.g. ID/serial numbers).
+ """
+
+ def __init__(self, digit=None, chntext=None):
+ self.digit = digit
+ self.chntext = chntext
+
+ # def chntext2digit(self):
+ # return chn2num(self.chntext)
+
+ def digit2chntext(self):
+ return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+ """
+    TELEPHONE class: telephone numbers.
+ """
+
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+ self.telephone = telephone
+ self.raw_chntext = raw_chntext
+ self.chntext = chntext
+
+ # def chntext2telephone(self):
+ # sil_parts = self.raw_chntext.split('')
+ # self.telephone = '-'.join([
+ # str(chn2num(p)) for p in sil_parts
+ # ])
+ # return self.telephone
+
+ def telephone2chntext(self, fixed=False):
+ if fixed:
+ sil_parts = self.telephone.split("-")
+ self.raw_chntext = "".join([num2chn(part, alt_two=False, use_units=False) for part in sil_parts])
+ self.chntext = self.raw_chntext.replace("", "")
+ else:
+ sp_parts = self.telephone.strip("+").split()
+ self.raw_chntext = "".join([num2chn(part, alt_two=False, use_units=False) for part in sp_parts])
+ self.chntext = self.raw_chntext.replace("", "")
+ return self.chntext
+
+
+class Fraction:
+ """
+    FRACTION class: fractions.
+ """
+
+ def __init__(self, fraction=None, chntext=None):
+ self.fraction = fraction
+ self.chntext = chntext
+
+ def chntext2fraction(self):
+        denominator, numerator = self.chntext.split("分之")
+ return chn2num(numerator) + "/" + chn2num(denominator)
+
+ def fraction2chntext(self):
+ numerator, denominator = self.fraction.split("/")
+        return num2chn(denominator) + "分之" + num2chn(numerator)
+
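+# Illustrative example: Fraction(fraction="2/3").fraction2chntext() -> "三分之二"
+# (the denominator is read first, as in spoken Chinese).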
+
+class Date:
+ """
+    DATE class: dates.
+ """
+
+ def __init__(self, date=None, chntext=None):
+ self.date = date
+ self.chntext = chntext
+
+ # def chntext2date(self):
+ # chntext = self.chntext
+ # try:
+    #         year, other = chntext.strip().split('年', maxsplit=1)
+    #         year = Digit(chntext=year).digit2chntext() + '年'
+ # except ValueError:
+ # other = chntext
+ # year = ''
+ # if other:
+ # try:
+    #             month, day = other.strip().split('月', maxsplit=1)
+    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
+ # except ValueError:
+ # day = chntext
+ # month = ''
+ # if day:
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+ # else:
+ # month = ''
+ # day = ''
+ # date = year + month + day
+ # self.date = date
+ # return self.date
+
+ def date2chntext(self):
+ date = self.date
+ try:
+            year, other = date.strip().split("年", 1)
+            year = Digit(digit=year).digit2chntext() + "年"
+ except ValueError:
+ other = date
+ year = ""
+ if other:
+ try:
+                month, day = other.strip().split("月", 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
+ except ValueError:
+ day = date
+ month = ""
+ if day:
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+ else:
+ month = ""
+ day = ""
+ chntext = year + month + day
+ self.chntext = chntext
+ return self.chntext
+
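+# Illustrative example: Date(date="2024年3月15日").date2chntext() -> "二零二四年三月十五日"
+# (years are read digit by digit, month/day as cardinals).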
+
+class Money:
+ """
+    MONEY class: money amounts.
+ """
+
+ def __init__(self, money=None, chntext=None):
+ self.money = money
+ self.chntext = chntext
+
+ # def chntext2money(self):
+ # return self.money
+
+ def money2chntext(self):
+ money = self.money
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(money)
+ if matchers:
+ for matcher in matchers:
+ money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
+ self.chntext = money
+ return self.chntext
+
+
+class Percentage:
+ """
+    PERCENTAGE class: percentages.
+ """
+
+ def __init__(self, percentage=None, chntext=None):
+ self.percentage = percentage
+ self.chntext = chntext
+
+ def chntext2percentage(self):
+        return chn2num(self.chntext.strip().strip("百分之")) + "%"
+
+ def percentage2chntext(self):
+        return "百分之" + num2chn(self.percentage.strip().strip("%"))
+
+
+def normalize_nsw(raw_text):
+ text = "^" + raw_text + "$"
+
+    # normalize dates
+    pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('date')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+    # normalize money amounts
+    pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('money')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+    # normalize landline / mobile phone numbers
+    # mobile
+    # http://www.jihaoba.com/news/show/13680
+    # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
+    # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
+    # China Telecom: 133, 153, 189, 180, 181, 177
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('telephone')
+ for matcher in matchers:
+ text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+    # landline
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fixed telephone')
+ for matcher in matchers:
+ text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+
+    # normalize fractions
+ pattern = re.compile(r"(\d+/\d+)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fraction')
+ for matcher in matchers:
+ text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+    # normalize percentages
+    text = text.replace("％", "%")
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('percentage')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+
+    # normalize number + quantifier
+    pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal+quantifier')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+    # normalize digit sequences (IDs/serial numbers)
+ pattern = re.compile(r"(\d{4,32})")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('digit')
+ for matcher in matchers:
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+    # normalize cardinal numbers
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+ # restore P2P, O2O, B2C, B2B etc
+    pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('particular')
+ for matcher in matchers:
+ text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+
+ return text.lstrip("^").rstrip("$")
+
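+# Illustrative end-to-end behaviour of normalize_nsw (a sketch; exact outputs depend on
+# the regex constants defined above and on the default 'mid' numbering system):
+#   normalize_nsw("去年增长了3.5%")      -> "去年增长了百分之三点五"
+#   normalize_nsw("手机号是13912345678") -> "手机号是一三九一二三四五六七八"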
+
+def remove_erhua(text):
+ """
+    Remove the erhua '儿' from retroflex-suffixed words (unless whitelisted):
+    他女儿在那边儿 -> 他女儿在那边
+ """
+
+ new_str = ""
+    while re.search("儿", text):
+        a = re.search("儿", text).span()
+ remove_er_flag = 0
+
+ if ER_WHITELIST_PATTERN.search(text):
+ b = ER_WHITELIST_PATTERN.search(text).span()
+ if b[0] <= a[0]:
+ remove_er_flag = 1
+
+ if remove_er_flag == 0:
+ new_str = new_str + text[0 : a[0]]
+ text = text[a[1] :]
+ else:
+ new_str = new_str + text[0 : b[1]]
+ text = text[b[1] :]
+
+ text = new_str + text
+ return text
+
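+# Illustrative example: remove_erhua("这块儿有点儿意思") -> "这块有点意思", while words on the
+# ER_WHITELIST above (e.g. 女儿 in the docstring example) keep their 儿.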
+
+def remove_space(text):
+ tokens = text.split()
+ new = []
+ for k, t in enumerate(tokens):
+ if k != 0:
+ if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]):
+ new.append(" ")
+ new.append(t)
+ return "".join(new)
+
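+# Illustrative example: remove_space("hello world 你好 world") -> "hello world你好world"
+# (a space is kept only between two ASCII alphanumeric tokens).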
+
+class TextNorm:
+ def __init__(
+ self,
+ to_banjiao: bool = False,
+ to_upper: bool = False,
+ to_lower: bool = False,
+ remove_fillers: bool = False,
+ remove_erhua: bool = False,
+ check_chars: bool = False,
+ remove_space: bool = False,
+ cc_mode: str = "",
+ ):
+ self.to_banjiao = to_banjiao
+ self.to_upper = to_upper
+ self.to_lower = to_lower
+ self.remove_fillers = remove_fillers
+ self.remove_erhua = remove_erhua
+ self.check_chars = check_chars
+ self.remove_space = remove_space
+
+ self.cc = None
+ if cc_mode:
+ from opencc import OpenCC # Open Chinese Convert: pip install opencc
+
+ self.cc = OpenCC(cc_mode)
+
+ def __call__(self, text):
+ if self.cc:
+ text = self.cc.convert(text)
+
+ if self.to_banjiao:
+ text = text.translate(QJ2BJ_TRANSFORM)
+
+ if self.to_upper:
+ text = text.upper()
+
+ if self.to_lower:
+ text = text.lower()
+
+ if self.remove_fillers:
+ for c in FILLER_CHARS:
+ text = text.replace(c, "")
+
+ if self.remove_erhua:
+ text = remove_erhua(text)
+
+ text = normalize_nsw(text)
+
+ text = text.translate(PUNCS_TRANSFORM)
+
+ if self.check_chars:
+ for c in text:
+ if not IN_VALID_CHARS.get(c):
+ print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr)
+ return ""
+
+ if self.remove_space:
+ text = remove_space(text)
+
+ return text
+
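+# Minimal usage sketch (illustrative; exact output depends on the constants defined above):
+#   tn = TextNorm(to_banjiao=True, to_lower=True)
+#   tn("ＡＢＣ占比3.5%")  ->  "abc占比百分之三点五"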
+
+if __name__ == "__main__":
+ p = argparse.ArgumentParser()
+
+ # normalizer options
+ p.add_argument("--to_banjiao", action="store_true", help="convert quanjiao chars to banjiao")
+ p.add_argument("--to_upper", action="store_true", help="convert to upper case")
+ p.add_argument("--to_lower", action="store_true", help="convert to lower case")
+    p.add_argument("--remove_fillers", action="store_true", help='remove filler chars such as "呃, 啊"')
+    p.add_argument("--remove_erhua", action="store_true", help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"')
+ p.add_argument("--check_chars", action="store_true", help="skip sentences containing illegal chars")
+ p.add_argument("--remove_space", action="store_true", help="remove whitespace")
+    p.add_argument("--cc_mode", choices=["", "t2s", "s2t"], default="", help="convert between traditional and simplified Chinese (t2s or s2t)")
+
+ # I/O options
+ p.add_argument("--log_interval", type=int, default=10000, help="log interval in number of processed lines")
+ p.add_argument("--has_key", action="store_true", help="will be deprecated, set --format ark instead")
+ p.add_argument("--format", type=str, choices=["txt", "ark", "tsv"], default="txt", help="input format")
+ p.add_argument("ifile", help="input filename, assume utf-8 encoding")
+ p.add_argument("ofile", help="output filename")
+
+ args = p.parse_args()
+
+ if args.has_key:
+ args.format = "ark"
+
+ normalizer = TextNorm(
+ to_banjiao=args.to_banjiao,
+ to_upper=args.to_upper,
+ to_lower=args.to_lower,
+ remove_fillers=args.remove_fillers,
+ remove_erhua=args.remove_erhua,
+ check_chars=args.check_chars,
+ remove_space=args.remove_space,
+ cc_mode=args.cc_mode,
+ )
+
+ ndone = 0
+ with open(args.ifile, "r", encoding="utf8") as istream, open(args.ofile, "w+", encoding="utf8") as ostream:
+ if args.format == "tsv":
+ reader = csv.DictReader(istream, delimiter="\t")
+ assert "TEXT" in reader.fieldnames
+ print("\t".join(reader.fieldnames), file=ostream)
+
+ for item in reader:
+ text = item["TEXT"]
+
+ if text:
+ text = normalizer(text)
+
+ if text:
+ item["TEXT"] = text
+ print("\t".join([item[f] for f in reader.fieldnames]), file=ostream)
+
+ ndone += 1
+ if ndone % args.log_interval == 0:
+ print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True)
+ else:
+ for l in istream:
+ key, text = "", ""
+ if args.format == "ark": # KALDI archive, line format: "key text"
+ cols = l.strip().split(maxsplit=1)
+ key, text = cols[0], cols[1] if len(cols) == 2 else ""
+ else:
+ text = l.strip()
+
+ if text:
+ text = normalizer(text)
+
+ if text:
+ if args.format == "ark":
+ print(key + "\t" + text, file=ostream)
+ else:
+ print(text, file=ostream)
+
+ ndone += 1
+ if ndone % args.log_interval == 0:
+ print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True)
+ print(f"text norm: {ndone} lines done in total.", file=sys.stderr, flush=True)
diff --git a/lmms_eval/tasks/librispeech/librispeech.yaml b/lmms_eval/tasks/librispeech/librispeech.yaml
new file mode 100755
index 00000000..dd6be3bc
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/librispeech.yaml
@@ -0,0 +1,6 @@
+group: librispeech
+task:
+ - librispeech_dev_clean
+ - librispeech_dev_other
+ - librispeech_test_clean
+ - librispeech_test_other
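+# Example invocation (a sketch, assuming the standard lmms-eval CLI and an audio model
+# such as qwen2_audio; adjust model_args and paths to your setup):
+#   python -m lmms_eval --model qwen2_audio --tasks librispeech --batch_size 1 --output_path ./logs/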
diff --git a/lmms_eval/tasks/librispeech/librispeech_dev_clean.yaml b/lmms_eval/tasks/librispeech/librispeech_dev_clean.yaml
new file mode 100644
index 00000000..237d955e
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/librispeech_dev_clean.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/librispeech
+dataset_kwargs:
+ token: True
+task : "librispeech_dev_clean"
+test_split: librispeech_dev_clean
+dataset_name: librispeech_dev_clean
+output_type: generate_until
+doc_to_visual: !function utils.librispeech_doc_to_audio
+doc_to_text: !function utils.librispeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.librispeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.librispeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/librispeech/librispeech_dev_other.yaml b/lmms_eval/tasks/librispeech/librispeech_dev_other.yaml
new file mode 100644
index 00000000..c4426db5
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/librispeech_dev_other.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/librispeech
+dataset_kwargs:
+ token: True
+task : "librispeech_dev_other"
+test_split: librispeech_dev_other
+dataset_name: librispeech_dev_other
+output_type: generate_until
+doc_to_visual: !function utils.librispeech_doc_to_audio
+doc_to_text: !function utils.librispeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.librispeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.librispeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/librispeech/librispeech_test_clean.yaml b/lmms_eval/tasks/librispeech/librispeech_test_clean.yaml
new file mode 100644
index 00000000..bf1c680e
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/librispeech_test_clean.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/librispeech
+dataset_kwargs:
+ token: True
+task : "librispeech_test_clean"
+test_split: librispeech_test_clean
+dataset_name: librispeech_test_clean
+output_type: generate_until
+doc_to_visual: !function utils.librispeech_doc_to_audio
+doc_to_text: !function utils.librispeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.librispeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.librispeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/librispeech/librispeech_test_other.yaml b/lmms_eval/tasks/librispeech/librispeech_test_other.yaml
new file mode 100644
index 00000000..ccac2ec0
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/librispeech_test_other.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/librispeech
+dataset_kwargs:
+ token: True
+task : "librispeech_test_other"
+test_split: librispeech_test_other
+dataset_name: librispeech_test_other
+output_type: generate_until
+doc_to_visual: !function utils.librispeech_doc_to_audio
+doc_to_text: !function utils.librispeech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.librispeech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.librispeech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/librispeech/utils.py b/lmms_eval/tasks/librispeech/utils.py
new file mode 100755
index 00000000..dfe4109b
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/utils.py
@@ -0,0 +1,217 @@
+import os
+import re
+import unicodedata
+
+import editdistance as ed # TODO: new package
+import zhconv # TODO: new package
+
+from lmms_eval.tasks.librispeech.cn_tn import TextNorm
+from lmms_eval.tasks.librispeech.whisper_normalizer.basic import BasicTextNormalizer
+from lmms_eval.tasks.librispeech.whisper_normalizer.english import EnglishTextNormalizer
+
+# Note: decoding the audio files in this task requires 'librosa' and 'soundfile' to be installed.
+english_normalizer = EnglishTextNormalizer()
+chinese_normalizer = TextNorm(
+ to_banjiao=False,
+ to_upper=False,
+ to_lower=False,
+ remove_fillers=False,
+ remove_erhua=False,
+ check_chars=False,
+ remove_space=False,
+ cc_mode="",
+)
+basic_normalizer = BasicTextNormalizer()
+
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
+
+def librispeech_doc_to_audio(doc):
+ return [doc["audio"]]
+
+
+def librispeech_doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"
+
+
+def librispeech_process_result(doc, result):
+ pred = result[0] if len(result) > 0 else ""
+
+ gt = doc["gt"]
+ source = doc["source"]
+ task = doc["task"]
+
+ data_dict = {"gt": gt, "pred": pred, "source": source, "task": task}
+
+ return {"wer": data_dict}
+
+
+PUNCS = "!,.?;:"
+
+
+def remove_sp(text, language):
+ gt = re.sub(r"<\|.*?\|>", " ", text)
+ gt = re.sub(rf"\s+", r" ", gt) # Replace consecutive spaces in the text with a single space.
+ gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
+ gt = gt.lstrip(" ")
+ if language == "zh":
+ gt = re.sub(rf"\s+", r"", gt)
+ return gt
+
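+# Illustrative example: remove_sp("<|en|> HELLO , WORLD", "en") -> "HELLO, WORLD"
+# (special tokens are stripped and spaces before punctuation are removed).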
+
+class EvaluationTokenizer(object):
+ """A generic evaluation-time tokenizer, which leverages built-in tokenizers
+ in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
+ lowercasing, punctuation removal and character tokenization, which are
+ applied after sacreBLEU tokenization.
+
+ Args:
+ tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
+ lowercase (bool): lowercase the text.
+ punctuation_removal (bool): remove punctuation (based on unicode
+ category) from text.
+ character_tokenization (bool): tokenize the text to characters.
+ """
+
+ SPACE = chr(32)
+ SPACE_ESCAPE = chr(9601)
+ # ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
+
+ def __init__(
+ self,
+ tokenizer_type: str = "13a",
+ lowercase: bool = False,
+ punctuation_removal: bool = False,
+ character_tokenization: bool = False,
+ ):
+ # from sacrebleu.tokenizers import TOKENIZERS
+ # from sacrebleu.tokenizers import tokenizer_none
+ from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
+ from sacrebleu.tokenizers.tokenizer_char import TokenizerChar
+ from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International
+ from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab
+ from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer
+ from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh
+
+ TOKENIZERS = {
+ "none": NoneTokenizer,
+ "13a": Tokenizer13a,
+ "intl": TokenizerV14International,
+ "zh": TokenizerZh,
+ "ja-mecab": TokenizerJaMecab,
+ "char": TokenizerChar,
+ }
+
+ assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
+ self.lowercase = lowercase
+ self.punctuation_removal = punctuation_removal
+ self.character_tokenization = character_tokenization
+ self.tokenizer = TOKENIZERS[tokenizer_type]
+ # self.tokenizer = tokenizer_none
+
+ @classmethod
+ def remove_punctuation(cls, sent: str):
+ """Remove punctuation based on Unicode category."""
+ return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t))
+
+ def tokenize(self, sent: str):
+ tokenized = self.tokenizer()(sent)
+
+ if self.punctuation_removal:
+ tokenized = self.remove_punctuation(tokenized)
+
+ if self.character_tokenization:
+ tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))
+
+ if self.lowercase:
+ tokenized = tokenized.lower()
+
+ return tokenized
+
+
+def compute_wer(refs, hyps, language):
+ distance = 0
+ ref_length = 0
+ tokenizer = EvaluationTokenizer(
+ tokenizer_type="none",
+ lowercase=True,
+ punctuation_removal=True,
+ character_tokenization=False,
+ )
+ for i in range(len(refs)):
+ ref = refs[i]
+ pred = hyps[i]
+ if language in ["yue"]:
+ ref = zhconv.convert(ref, "zh-cn")
+ pred = zhconv.convert(pred, "zh-cn")
+ if language in ["en"]:
+ ref = english_normalizer(ref)
+ pred = english_normalizer(pred)
+ if language in ["zh"]:
+ ref = chinese_normalizer(ref)
+ pred = chinese_normalizer(pred)
+ else:
+ ref = basic_normalizer(ref)
+ pred = basic_normalizer(pred)
+ ref_items = tokenizer.tokenize(ref).split()
+ pred_items = tokenizer.tokenize(pred).split()
+ if language in ["zh", "yue"]:
+ ref_items = [x for x in "".join(ref_items)]
+ pred_items = [x for x in "".join(pred_items)]
+ if i == 0:
+ print(f"ref: {ref}")
+ print(f"pred: {pred}")
+ print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
+            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{pred_items[0]}")
+ distance += ed.eval(ref_items, pred_items)
+ ref_length += len(ref_items)
+ return distance / ref_length
+
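+# Rough sanity check (illustrative, with made-up strings):
+#   compute_wer(["hello world"], ["hello word"], "en") -> 0.5   # one edit over two reference words
+# Note that the first ref/pred pair is also printed for debugging.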
+
+def librispeech_wer(results, args):
+ # lan = args["language"]
+ refs, hyps = [], []
+ # results_list = results_dict[source]
+ for result in results:
+ lan = result["task"][4:]
+ gt = result["gt"]
+ response = result["pred"]
+ gt = remove_sp(gt, lan)
+ response = remove_sp(response, lan)
+ refs.append(gt)
+ hyps.append(response)
+ wer = compute_wer(refs, hyps, lan)
+ # print(f"source: {source} cnt: {len(refs)} wer: {wer:.4f}")
+ return wer * 100
+
+ # for gt, response, source, audio_path in zip(merged_gts, merged_responses, merged_sources, merged_audio_paths):
+ # results.append({
+ # 'gt': gt,
+ # 'response': response,
+ # 'source': source,
+ # 'audio_path': audio_path,
+ # })
+ # time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
+ # results_file = f'{args.dataset}_{time_prefix}.json'
+ # json.dump(results, open(results_file, 'w'))
+ results_dict = {}
+ for item in results:
+ source = item["source"]
+ results_dict.setdefault(source, []).append(item)
+ lan = ds_collections[args.dataset]["language"]
+ for source in results_dict:
+ refs, hyps = [], []
+ results_list = results_dict[source]
+ for result in results_list:
+ gt = result["gt"]
+ response = result["response"]
+ gt = remove_sp(gt, lan)
+ response = remove_sp(response, lan)
+ refs.append(gt)
+ hyps.append(response)
+ wer = compute_wer(refs, hyps, lan)
+ print(f"source: {source} cnt: {len(refs)} wer: {wer:.4f}")
+
+ pass
diff --git a/lmms_eval/tasks/librispeech/whisper_normalizer/basic.py b/lmms_eval/tasks/librispeech/whisper_normalizer/basic.py
new file mode 100644
index 00000000..00a54dcc
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/whisper_normalizer/basic.py
@@ -0,0 +1,58 @@
+import re
+import unicodedata
+
+import regex
+
+# non-ASCII letters that are not separated by "NFKD" normalization
+ADDITIONAL_DIACRITICS = {
+ "ล": "oe",
+ "ล": "OE",
+ "รธ": "o",
+ "ร": "O",
+ "รฆ": "ae",
+ "ร": "AE",
+ "ร": "ss",
+ "แบ": "SS",
+ "ฤ": "d",
+ "ฤ": "D",
+ "รฐ": "d",
+ "ร": "D",
+ "รพ": "th",
+ "ร": "th",
+ "ล": "l",
+ "ล": "L",
+}
+
+
+def remove_symbols_and_diacritics(s: str, keep=""):
+ """
+ Replace any other markers, symbols, and punctuations with a space,
+ and drop any diacritics (category 'Mn' and some manual mappings)
+ """
+ return "".join(c if c in keep else ADDITIONAL_DIACRITICS[c] if c in ADDITIONAL_DIACRITICS else "" if unicodedata.category(c) == "Mn" else " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKD", s))
+
+
+def remove_symbols(s: str):
+ """
+ Replace any other markers, symbols, punctuations with a space, keeping diacritics
+ """
+ return "".join(" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s))
+
+
+class BasicTextNormalizer:
+ def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
+ self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+ self.split_letters = split_letters
+
+ def __call__(self, s: str):
+ s = s.lower()
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
+ s = self.clean(s).lower()
+
+ if self.split_letters:
+ s = " ".join(regex.findall(r"\X", s, regex.U))
+
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
+
+ return s
diff --git a/lmms_eval/tasks/librispeech/whisper_normalizer/english.json b/lmms_eval/tasks/librispeech/whisper_normalizer/english.json
new file mode 100644
index 00000000..566e4812
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/whisper_normalizer/english.json
@@ -0,0 +1,1741 @@
+{
+ "accessorise": "accessorize",
+ "accessorised": "accessorized",
+ "accessorises": "accessorizes",
+ "accessorising": "accessorizing",
+ "acclimatisation": "acclimatization",
+ "acclimatise": "acclimatize",
+ "acclimatised": "acclimatized",
+ "acclimatises": "acclimatizes",
+ "acclimatising": "acclimatizing",
+ "accoutrements": "accouterments",
+ "aeon": "eon",
+ "aeons": "eons",
+ "aerogramme": "aerogram",
+ "aerogrammes": "aerograms",
+ "aeroplane": "airplane",
+ "aeroplanes": "airplanes",
+ "aesthete": "esthete",
+ "aesthetes": "esthetes",
+ "aesthetic": "esthetic",
+ "aesthetically": "esthetically",
+ "aesthetics": "esthetics",
+ "aetiology": "etiology",
+ "ageing": "aging",
+ "aggrandisement": "aggrandizement",
+ "agonise": "agonize",
+ "agonised": "agonized",
+ "agonises": "agonizes",
+ "agonising": "agonizing",
+ "agonisingly": "agonizingly",
+ "almanack": "almanac",
+ "almanacks": "almanacs",
+ "aluminium": "aluminum",
+ "amortisable": "amortizable",
+ "amortisation": "amortization",
+ "amortisations": "amortizations",
+ "amortise": "amortize",
+ "amortised": "amortized",
+ "amortises": "amortizes",
+ "amortising": "amortizing",
+ "amphitheatre": "amphitheater",
+ "amphitheatres": "amphitheaters",
+ "anaemia": "anemia",
+ "anaemic": "anemic",
+ "anaesthesia": "anesthesia",
+ "anaesthetic": "anesthetic",
+ "anaesthetics": "anesthetics",
+ "anaesthetise": "anesthetize",
+ "anaesthetised": "anesthetized",
+ "anaesthetises": "anesthetizes",
+ "anaesthetising": "anesthetizing",
+ "anaesthetist": "anesthetist",
+ "anaesthetists": "anesthetists",
+ "anaesthetize": "anesthetize",
+ "anaesthetized": "anesthetized",
+ "anaesthetizes": "anesthetizes",
+ "anaesthetizing": "anesthetizing",
+ "analogue": "analog",
+ "analogues": "analogs",
+ "analyse": "analyze",
+ "analysed": "analyzed",
+ "analyses": "analyzes",
+ "analysing": "analyzing",
+ "anglicise": "anglicize",
+ "anglicised": "anglicized",
+ "anglicises": "anglicizes",
+ "anglicising": "anglicizing",
+ "annualised": "annualized",
+ "antagonise": "antagonize",
+ "antagonised": "antagonized",
+ "antagonises": "antagonizes",
+ "antagonising": "antagonizing",
+ "apologise": "apologize",
+ "apologised": "apologized",
+ "apologises": "apologizes",
+ "apologising": "apologizing",
+ "appal": "appall",
+ "appals": "appalls",
+ "appetiser": "appetizer",
+ "appetisers": "appetizers",
+ "appetising": "appetizing",
+ "appetisingly": "appetizingly",
+ "arbour": "arbor",
+ "arbours": "arbors",
+ "archeological": "archaeological",
+ "archaeologically": "archeologically",
+ "archaeologist": "archeologist",
+ "archaeologists": "archeologists",
+ "archaeology": "archeology",
+ "ardour": "ardor",
+ "armour": "armor",
+ "armoured": "armored",
+ "armourer": "armorer",
+ "armourers": "armorers",
+ "armouries": "armories",
+ "armoury": "armory",
+ "artefact": "artifact",
+ "artefacts": "artifacts",
+ "authorise": "authorize",
+ "authorised": "authorized",
+ "authorises": "authorizes",
+ "authorising": "authorizing",
+ "axe": "ax",
+ "backpedalled": "backpedaled",
+ "backpedalling": "backpedaling",
+ "bannister": "banister",
+ "bannisters": "banisters",
+ "baptise": "baptize",
+ "baptised": "baptized",
+ "baptises": "baptizes",
+ "baptising": "baptizing",
+ "bastardise": "bastardize",
+ "bastardised": "bastardized",
+ "bastardises": "bastardizes",
+ "bastardising": "bastardizing",
+ "battleax": "battleaxe",
+ "baulk": "balk",
+ "baulked": "balked",
+ "baulking": "balking",
+ "baulks": "balks",
+ "bedevilled": "bedeviled",
+ "bedevilling": "bedeviling",
+ "behaviour": "behavior",
+ "behavioural": "behavioral",
+ "behaviourism": "behaviorism",
+ "behaviourist": "behaviorist",
+ "behaviourists": "behaviorists",
+ "behaviours": "behaviors",
+ "behove": "behoove",
+ "behoved": "behooved",
+ "behoves": "behooves",
+ "bejewelled": "bejeweled",
+ "belabour": "belabor",
+ "belaboured": "belabored",
+ "belabouring": "belaboring",
+ "belabours": "belabors",
+ "bevelled": "beveled",
+ "bevvies": "bevies",
+ "bevvy": "bevy",
+ "biassed": "biased",
+ "biassing": "biasing",
+ "bingeing": "binging",
+ "bougainvillaea": "bougainvillea",
+ "bougainvillaeas": "bougainvilleas",
+ "bowdlerise": "bowdlerize",
+ "bowdlerised": "bowdlerized",
+ "bowdlerises": "bowdlerizes",
+ "bowdlerising": "bowdlerizing",
+ "breathalyse": "breathalyze",
+ "breathalysed": "breathalyzed",
+ "breathalyser": "breathalyzer",
+ "breathalysers": "breathalyzers",
+ "breathalyses": "breathalyzes",
+ "breathalysing": "breathalyzing",
+ "brutalise": "brutalize",
+ "brutalised": "brutalized",
+ "brutalises": "brutalizes",
+ "brutalising": "brutalizing",
+ "busses": "buses",
+ "bussing": "busing",
+ "caesarean": "cesarean",
+ "caesareans": "cesareans",
+ "calibre": "caliber",
+ "calibres": "calibers",
+ "calliper": "caliper",
+ "callipers": "calipers",
+ "callisthenics": "calisthenics",
+ "canalise": "canalize",
+ "canalised": "canalized",
+ "canalises": "canalizes",
+ "canalising": "canalizing",
+ "cancelation": "cancellation",
+ "cancelations": "cancellations",
+ "cancelled": "canceled",
+ "cancelling": "canceling",
+ "candour": "candor",
+ "cannibalise": "cannibalize",
+ "cannibalised": "cannibalized",
+ "cannibalises": "cannibalizes",
+ "cannibalising": "cannibalizing",
+ "canonise": "canonize",
+ "canonised": "canonized",
+ "canonises": "canonizes",
+ "canonising": "canonizing",
+ "capitalise": "capitalize",
+ "capitalised": "capitalized",
+ "capitalises": "capitalizes",
+ "capitalising": "capitalizing",
+ "caramelise": "caramelize",
+ "caramelised": "caramelized",
+ "caramelises": "caramelizes",
+ "caramelising": "caramelizing",
+ "carbonise": "carbonize",
+ "carbonised": "carbonized",
+ "carbonises": "carbonizes",
+ "carbonising": "carbonizing",
+ "carolled": "caroled",
+ "carolling": "caroling",
+ "catalogue": "catalog",
+ "catalogued": "cataloged",
+ "catalogues": "catalogs",
+ "cataloguing": "cataloging",
+ "catalyse": "catalyze",
+ "catalysed": "catalyzed",
+ "catalyses": "catalyzes",
+ "catalysing": "catalyzing",
+ "categorise": "categorize",
+ "categorised": "categorized",
+ "categorises": "categorizes",
+ "categorising": "categorizing",
+ "cauterise": "cauterize",
+ "cauterised": "cauterized",
+ "cauterises": "cauterizes",
+ "cauterising": "cauterizing",
+ "cavilled": "caviled",
+ "cavilling": "caviling",
+ "centigramme": "centigram",
+ "centigrammes": "centigrams",
+ "centilitre": "centiliter",
+ "centilitres": "centiliters",
+ "centimetre": "centimeter",
+ "centimetres": "centimeters",
+ "centralise": "centralize",
+ "centralised": "centralized",
+ "centralises": "centralizes",
+ "centralising": "centralizing",
+ "centre": "center",
+ "centred": "centered",
+ "centrefold": "centerfold",
+ "centrefolds": "centerfolds",
+ "centrepiece": "centerpiece",
+ "centrepieces": "centerpieces",
+ "centres": "centers",
+ "channelled": "channeled",
+ "channelling": "channeling",
+ "characterise": "characterize",
+ "characterised": "characterized",
+ "characterises": "characterizes",
+ "characterising": "characterizing",
+ "cheque": "check",
+ "chequebook": "checkbook",
+ "chequebooks": "checkbooks",
+ "chequered": "checkered",
+ "cheques": "checks",
+ "chilli": "chili",
+ "chimaera": "chimera",
+ "chimaeras": "chimeras",
+ "chiselled": "chiseled",
+ "chiselling": "chiseling",
+ "circularise": "circularize",
+ "circularised": "circularized",
+ "circularises": "circularizes",
+ "circularising": "circularizing",
+ "civilise": "civilize",
+ "civilised": "civilized",
+ "civilises": "civilizes",
+ "civilising": "civilizing",
+ "clamour": "clamor",
+ "clamoured": "clamored",
+ "clamouring": "clamoring",
+ "clamours": "clamors",
+ "clangour": "clangor",
+ "clarinettist": "clarinetist",
+ "clarinettists": "clarinetists",
+ "collectivise": "collectivize",
+ "collectivised": "collectivized",
+ "collectivises": "collectivizes",
+ "collectivising": "collectivizing",
+ "colonisation": "colonization",
+ "colonise": "colonize",
+ "colonised": "colonized",
+ "coloniser": "colonizer",
+ "colonisers": "colonizers",
+ "colonises": "colonizes",
+ "colonising": "colonizing",
+ "colour": "color",
+ "colourant": "colorant",
+ "colourants": "colorants",
+ "coloured": "colored",
+ "coloureds": "coloreds",
+ "colourful": "colorful",
+ "colourfully": "colorfully",
+ "colouring": "coloring",
+ "colourize": "colorize",
+ "colourized": "colorized",
+ "colourizes": "colorizes",
+ "colourizing": "colorizing",
+ "colourless": "colorless",
+ "colours": "colors",
+ "commercialise": "commercialize",
+ "commercialised": "commercialized",
+ "commercialises": "commercializes",
+ "commercialising": "commercializing",
+ "compartmentalise": "compartmentalize",
+ "compartmentalised": "compartmentalized",
+ "compartmentalises": "compartmentalizes",
+ "compartmentalising": "compartmentalizing",
+ "computerise": "computerize",
+ "computerised": "computerized",
+ "computerises": "computerizes",
+ "computerising": "computerizing",
+ "conceptualise": "conceptualize",
+ "conceptualised": "conceptualized",
+ "conceptualises": "conceptualizes",
+ "conceptualising": "conceptualizing",
+ "connexion": "connection",
+ "connexions": "connections",
+ "contextualise": "contextualize",
+ "contextualised": "contextualized",
+ "contextualises": "contextualizes",
+ "contextualising": "contextualizing",
+ "cosier": "cozier",
+ "cosies": "cozies",
+ "cosiest": "coziest",
+ "cosily": "cozily",
+ "cosiness": "coziness",
+ "cosy": "cozy",
+ "councillor": "councilor",
+ "councillors": "councilors",
+ "counselled": "counseled",
+ "counselling": "counseling",
+ "counsellor": "counselor",
+ "counsellors": "counselors",
+ "crenelated": "crenellated",
+ "criminalise": "criminalize",
+ "criminalised": "criminalized",
+ "criminalises": "criminalizes",
+ "criminalising": "criminalizing",
+ "criticise": "criticize",
+ "criticised": "criticized",
+ "criticises": "criticizes",
+ "criticising": "criticizing",
+ "crueller": "crueler",
+ "cruellest": "cruelest",
+ "crystallisation": "crystallization",
+ "crystallise": "crystallize",
+ "crystallised": "crystallized",
+ "crystallises": "crystallizes",
+ "crystallising": "crystallizing",
+ "cudgelled": "cudgeled",
+ "cudgelling": "cudgeling",
+ "customise": "customize",
+ "customised": "customized",
+ "customises": "customizes",
+ "customising": "customizing",
+ "cypher": "cipher",
+ "cyphers": "ciphers",
+ "decentralisation": "decentralization",
+ "decentralise": "decentralize",
+ "decentralised": "decentralized",
+ "decentralises": "decentralizes",
+ "decentralising": "decentralizing",
+ "decriminalisation": "decriminalization",
+ "decriminalise": "decriminalize",
+ "decriminalised": "decriminalized",
+ "decriminalises": "decriminalizes",
+ "decriminalising": "decriminalizing",
+ "defence": "defense",
+ "defenceless": "defenseless",
+ "defences": "defenses",
+ "dehumanisation": "dehumanization",
+ "dehumanise": "dehumanize",
+ "dehumanised": "dehumanized",
+ "dehumanises": "dehumanizes",
+ "dehumanising": "dehumanizing",
+ "demeanour": "demeanor",
+ "demilitarisation": "demilitarization",
+ "demilitarise": "demilitarize",
+ "demilitarised": "demilitarized",
+ "demilitarises": "demilitarizes",
+ "demilitarising": "demilitarizing",
+ "demobilisation": "demobilization",
+ "demobilise": "demobilize",
+ "demobilised": "demobilized",
+ "demobilises": "demobilizes",
+ "demobilising": "demobilizing",
+ "democratisation": "democratization",
+ "democratise": "democratize",
+ "democratised": "democratized",
+ "democratises": "democratizes",
+ "democratising": "democratizing",
+ "demonise": "demonize",
+ "demonised": "demonized",
+ "demonises": "demonizes",
+ "demonising": "demonizing",
+ "demoralisation": "demoralization",
+ "demoralise": "demoralize",
+ "demoralised": "demoralized",
+ "demoralises": "demoralizes",
+ "demoralising": "demoralizing",
+ "denationalisation": "denationalization",
+ "denationalise": "denationalize",
+ "denationalised": "denationalized",
+ "denationalises": "denationalizes",
+ "denationalising": "denationalizing",
+ "deodorise": "deodorize",
+ "deodorised": "deodorized",
+ "deodorises": "deodorizes",
+ "deodorising": "deodorizing",
+ "depersonalise": "depersonalize",
+ "depersonalised": "depersonalized",
+ "depersonalises": "depersonalizes",
+ "depersonalising": "depersonalizing",
+ "deputise": "deputize",
+ "deputised": "deputized",
+ "deputises": "deputizes",
+ "deputising": "deputizing",
+ "desensitisation": "desensitization",
+ "desensitise": "desensitize",
+ "desensitised": "desensitized",
+ "desensitises": "desensitizes",
+ "desensitising": "desensitizing",
+ "destabilisation": "destabilization",
+ "destabilise": "destabilize",
+ "destabilised": "destabilized",
+ "destabilises": "destabilizes",
+ "destabilising": "destabilizing",
+ "dialled": "dialed",
+ "dialling": "dialing",
+ "dialogue": "dialog",
+ "dialogues": "dialogs",
+ "diarrhoea": "diarrhea",
+ "digitise": "digitize",
+ "digitised": "digitized",
+ "digitises": "digitizes",
+ "digitising": "digitizing",
+ "disc": "disk",
+ "discolour": "discolor",
+ "discoloured": "discolored",
+ "discolouring": "discoloring",
+ "discolours": "discolors",
+ "discs": "disks",
+ "disembowelled": "disemboweled",
+ "disembowelling": "disemboweling",
+ "disfavour": "disfavor",
+ "dishevelled": "disheveled",
+ "dishonour": "dishonor",
+ "dishonourable": "dishonorable",
+ "dishonourably": "dishonorably",
+ "dishonoured": "dishonored",
+ "dishonouring": "dishonoring",
+ "dishonours": "dishonors",
+ "disorganisation": "disorganization",
+ "disorganised": "disorganized",
+ "distil": "distill",
+ "distils": "distills",
+ "dramatisation": "dramatization",
+ "dramatisations": "dramatizations",
+ "dramatise": "dramatize",
+ "dramatised": "dramatized",
+ "dramatises": "dramatizes",
+ "dramatising": "dramatizing",
+ "draught": "draft",
+ "draughtboard": "draftboard",
+ "draughtboards": "draftboards",
+ "draughtier": "draftier",
+ "draughtiest": "draftiest",
+ "draughts": "drafts",
+ "draughtsman": "draftsman",
+ "draughtsmanship": "draftsmanship",
+ "draughtsmen": "draftsmen",
+ "draughtswoman": "draftswoman",
+ "draughtswomen": "draftswomen",
+ "draughty": "drafty",
+ "drivelled": "driveled",
+ "drivelling": "driveling",
+ "duelled": "dueled",
+ "duelling": "dueling",
+ "economise": "economize",
+ "economised": "economized",
+ "economises": "economizes",
+ "economising": "economizing",
+ "edoema": "edema",
+ "editorialise": "editorialize",
+ "editorialised": "editorialized",
+ "editorialises": "editorializes",
+ "editorialising": "editorializing",
+ "empathise": "empathize",
+ "empathised": "empathized",
+ "empathises": "empathizes",
+ "empathising": "empathizing",
+ "emphasise": "emphasize",
+ "emphasised": "emphasized",
+ "emphasises": "emphasizes",
+ "emphasising": "emphasizing",
+ "enamelled": "enameled",
+ "enamelling": "enameling",
+ "enamoured": "enamored",
+ "encyclopaedia": "encyclopedia",
+ "encyclopaedias": "encyclopedias",
+ "encyclopaedic": "encyclopedic",
+ "endeavour": "endeavor",
+ "endeavoured": "endeavored",
+ "endeavouring": "endeavoring",
+ "endeavours": "endeavors",
+ "energise": "energize",
+ "energised": "energized",
+ "energises": "energizes",
+ "energising": "energizing",
+ "enrol": "enroll",
+ "enrols": "enrolls",
+ "enthral": "enthrall",
+ "enthrals": "enthralls",
+ "epaulette": "epaulet",
+ "epaulettes": "epaulets",
+ "epicentre": "epicenter",
+ "epicentres": "epicenters",
+ "epilogue": "epilog",
+ "epilogues": "epilogs",
+ "epitomise": "epitomize",
+ "epitomised": "epitomized",
+ "epitomises": "epitomizes",
+ "epitomising": "epitomizing",
+ "equalisation": "equalization",
+ "equalise": "equalize",
+ "equalised": "equalized",
+ "equaliser": "equalizer",
+ "equalisers": "equalizers",
+ "equalises": "equalizes",
+ "equalising": "equalizing",
+ "eulogise": "eulogize",
+ "eulogised": "eulogized",
+ "eulogises": "eulogizes",
+ "eulogising": "eulogizing",
+ "evangelise": "evangelize",
+ "evangelised": "evangelized",
+ "evangelises": "evangelizes",
+ "evangelising": "evangelizing",
+ "exorcise": "exorcize",
+ "exorcised": "exorcized",
+ "exorcises": "exorcizes",
+ "exorcising": "exorcizing",
+ "extemporisation": "extemporization",
+ "extemporise": "extemporize",
+ "extemporised": "extemporized",
+ "extemporises": "extemporizes",
+ "extemporising": "extemporizing",
+ "externalisation": "externalization",
+ "externalisations": "externalizations",
+ "externalise": "externalize",
+ "externalised": "externalized",
+ "externalises": "externalizes",
+ "externalising": "externalizing",
+ "factorise": "factorize",
+ "factorised": "factorized",
+ "factorises": "factorizes",
+ "factorising": "factorizing",
+ "faecal": "fecal",
+ "faeces": "feces",
+ "familiarisation": "familiarization",
+ "familiarise": "familiarize",
+ "familiarised": "familiarized",
+ "familiarises": "familiarizes",
+ "familiarising": "familiarizing",
+ "fantasise": "fantasize",
+ "fantasised": "fantasized",
+ "fantasises": "fantasizes",
+ "fantasising": "fantasizing",
+ "favour": "favor",
+ "favourable": "favorable",
+ "favourably": "favorably",
+ "favoured": "favored",
+ "favouring": "favoring",
+ "favourite": "favorite",
+ "favourites": "favorites",
+ "favouritism": "favoritism",
+ "favours": "favors",
+ "feminise": "feminize",
+ "feminised": "feminized",
+ "feminises": "feminizes",
+ "feminising": "feminizing",
+ "fertilisation": "fertilization",
+ "fertilise": "fertilize",
+ "fertilised": "fertilized",
+ "fertiliser": "fertilizer",
+ "fertilisers": "fertilizers",
+ "fertilises": "fertilizes",
+ "fertilising": "fertilizing",
+ "fervour": "fervor",
+ "fibre": "fiber",
+ "fibreglass": "fiberglass",
+ "fibres": "fibers",
+ "fictionalisation": "fictionalization",
+ "fictionalisations": "fictionalizations",
+ "fictionalise": "fictionalize",
+ "fictionalised": "fictionalized",
+ "fictionalises": "fictionalizes",
+ "fictionalising": "fictionalizing",
+ "fillet": "filet",
+ "filleted": "fileted",
+ "filleting": "fileting",
+ "fillets": "filets",
+ "finalisation": "finalization",
+ "finalise": "finalize",
+ "finalised": "finalized",
+ "finalises": "finalizes",
+ "finalising": "finalizing",
+ "flautist": "flutist",
+ "flautists": "flutists",
+ "flavour": "flavor",
+ "flavoured": "flavored",
+ "flavouring": "flavoring",
+ "flavourings": "flavorings",
+ "flavourless": "flavorless",
+ "flavours": "flavors",
+ "flavoursome": "flavorsome",
+ "flyer / flier": "flier / flyer",
+ "foetal": "fetal",
+ "foetid": "fetid",
+ "foetus": "fetus",
+ "foetuses": "fetuses",
+ "formalisation": "formalization",
+ "formalise": "formalize",
+ "formalised": "formalized",
+ "formalises": "formalizes",
+ "formalising": "formalizing",
+ "fossilisation": "fossilization",
+ "fossilise": "fossilize",
+ "fossilised": "fossilized",
+ "fossilises": "fossilizes",
+ "fossilising": "fossilizing",
+ "fraternisation": "fraternization",
+ "fraternise": "fraternize",
+ "fraternised": "fraternized",
+ "fraternises": "fraternizes",
+ "fraternising": "fraternizing",
+ "fulfil": "fulfill",
+ "fulfilment": "fulfillment",
+ "fulfils": "fulfills",
+ "funnelled": "funneled",
+ "funnelling": "funneling",
+ "galvanise": "galvanize",
+ "galvanised": "galvanized",
+ "galvanises": "galvanizes",
+ "galvanising": "galvanizing",
+ "gambolled": "gamboled",
+ "gambolling": "gamboling",
+ "gaol": "jail",
+ "gaolbird": "jailbird",
+ "gaolbirds": "jailbirds",
+ "gaolbreak": "jailbreak",
+ "gaolbreaks": "jailbreaks",
+ "gaoled": "jailed",
+ "gaoler": "jailer",
+ "gaolers": "jailers",
+ "gaoling": "jailing",
+ "gaols": "jails",
+ "gasses": "gases",
+ "gage": "gauge",
+ "gaged": "gauged",
+ "gages": "gauges",
+ "gaging": "gauging",
+ "generalisation": "generalization",
+ "generalisations": "generalizations",
+ "generalise": "generalize",
+ "generalised": "generalized",
+ "generalises": "generalizes",
+ "generalising": "generalizing",
+ "ghettoise": "ghettoize",
+ "ghettoised": "ghettoized",
+ "ghettoises": "ghettoizes",
+ "ghettoising": "ghettoizing",
+ "gipsies": "gypsies",
+ "glamorise": "glamorize",
+ "glamorised": "glamorized",
+ "glamorises": "glamorizes",
+ "glamorising": "glamorizing",
+ "glamor": "glamour",
+ "globalisation": "globalization",
+ "globalise": "globalize",
+ "globalised": "globalized",
+ "globalises": "globalizes",
+ "globalising": "globalizing",
+ "glueing": "gluing",
+ "goitre": "goiter",
+ "goitres": "goiters",
+ "gonorrhoea": "gonorrhea",
+ "gramme": "gram",
+ "grammes": "grams",
+ "gravelled": "graveled",
+ "grey": "gray",
+ "greyed": "grayed",
+ "greying": "graying",
+ "greyish": "grayish",
+ "greyness": "grayness",
+ "greys": "grays",
+ "grovelled": "groveled",
+ "grovelling": "groveling",
+ "groyne": "groin",
+ "groynes": "groins",
+ "gruelling": "grueling",
+ "gruellingly": "gruelingly",
+ "gryphon": "griffin",
+ "gryphons": "griffins",
+ "gynaecological": "gynecological",
+ "gynaecologist": "gynecologist",
+ "gynaecologists": "gynecologists",
+ "gynaecology": "gynecology",
+ "haematological": "hematological",
+ "haematologist": "hematologist",
+ "haematologists": "hematologists",
+ "haematology": "hematology",
+ "haemoglobin": "hemoglobin",
+ "haemophilia": "hemophilia",
+ "haemophiliac": "hemophiliac",
+ "haemophiliacs": "hemophiliacs",
+ "haemorrhage": "hemorrhage",
+ "haemorrhaged": "hemorrhaged",
+ "haemorrhages": "hemorrhages",
+ "haemorrhaging": "hemorrhaging",
+ "haemorrhoids": "hemorrhoids",
+ "harbour": "harbor",
+ "harboured": "harbored",
+ "harbouring": "harboring",
+ "harbours": "harbors",
+ "harmonisation": "harmonization",
+ "harmonise": "harmonize",
+ "harmonised": "harmonized",
+ "harmonises": "harmonizes",
+ "harmonising": "harmonizing",
+ "homoeopath": "homeopath",
+ "homoeopathic": "homeopathic",
+ "homoeopaths": "homeopaths",
+ "homoeopathy": "homeopathy",
+ "homogenise": "homogenize",
+ "homogenised": "homogenized",
+ "homogenises": "homogenizes",
+ "homogenising": "homogenizing",
+ "honour": "honor",
+ "honourable": "honorable",
+ "honourably": "honorably",
+ "honoured": "honored",
+ "honouring": "honoring",
+ "honours": "honors",
+ "hospitalisation": "hospitalization",
+ "hospitalise": "hospitalize",
+ "hospitalised": "hospitalized",
+ "hospitalises": "hospitalizes",
+ "hospitalising": "hospitalizing",
+ "humanise": "humanize",
+ "humanised": "humanized",
+ "humanises": "humanizes",
+ "humanising": "humanizing",
+ "humour": "humor",
+ "humoured": "humored",
+ "humouring": "humoring",
+ "humourless": "humorless",
+ "humours": "humors",
+ "hybridise": "hybridize",
+ "hybridised": "hybridized",
+ "hybridises": "hybridizes",
+ "hybridising": "hybridizing",
+ "hypnotise": "hypnotize",
+ "hypnotised": "hypnotized",
+ "hypnotises": "hypnotizes",
+ "hypnotising": "hypnotizing",
+ "hypothesise": "hypothesize",
+ "hypothesised": "hypothesized",
+ "hypothesises": "hypothesizes",
+ "hypothesising": "hypothesizing",
+ "idealisation": "idealization",
+ "idealise": "idealize",
+ "idealised": "idealized",
+ "idealises": "idealizes",
+ "idealising": "idealizing",
+ "idolise": "idolize",
+ "idolised": "idolized",
+ "idolises": "idolizes",
+ "idolising": "idolizing",
+ "immobilisation": "immobilization",
+ "immobilise": "immobilize",
+ "immobilised": "immobilized",
+ "immobiliser": "immobilizer",
+ "immobilisers": "immobilizers",
+ "immobilises": "immobilizes",
+ "immobilising": "immobilizing",
+ "immortalise": "immortalize",
+ "immortalised": "immortalized",
+ "immortalises": "immortalizes",
+ "immortalising": "immortalizing",
+ "immunisation": "immunization",
+ "immunise": "immunize",
+ "immunised": "immunized",
+ "immunises": "immunizes",
+ "immunising": "immunizing",
+ "impanelled": "impaneled",
+ "impanelling": "impaneling",
+ "imperilled": "imperiled",
+ "imperilling": "imperiling",
+ "individualise": "individualize",
+ "individualised": "individualized",
+ "individualises": "individualizes",
+ "individualising": "individualizing",
+ "industrialise": "industrialize",
+ "industrialised": "industrialized",
+ "industrialises": "industrializes",
+ "industrialising": "industrializing",
+ "inflexion": "inflection",
+ "inflexions": "inflections",
+ "initialise": "initialize",
+ "initialised": "initialized",
+ "initialises": "initializes",
+ "initialising": "initializing",
+ "initialled": "initialed",
+ "initialling": "initialing",
+ "instal": "install",
+ "instalment": "installment",
+ "instalments": "installments",
+ "instals": "installs",
+ "instil": "instill",
+ "instils": "instills",
+ "institutionalisation": "institutionalization",
+ "institutionalise": "institutionalize",
+ "institutionalised": "institutionalized",
+ "institutionalises": "institutionalizes",
+ "institutionalising": "institutionalizing",
+ "intellectualise": "intellectualize",
+ "intellectualised": "intellectualized",
+ "intellectualises": "intellectualizes",
+ "intellectualising": "intellectualizing",
+ "internalisation": "internalization",
+ "internalise": "internalize",
+ "internalised": "internalized",
+ "internalises": "internalizes",
+ "internalising": "internalizing",
+ "internationalisation": "internationalization",
+ "internationalise": "internationalize",
+ "internationalised": "internationalized",
+ "internationalises": "internationalizes",
+ "internationalising": "internationalizing",
+ "ionisation": "ionization",
+ "ionise": "ionize",
+ "ionised": "ionized",
+ "ioniser": "ionizer",
+ "ionisers": "ionizers",
+ "ionises": "ionizes",
+ "ionising": "ionizing",
+ "italicise": "italicize",
+ "italicised": "italicized",
+ "italicises": "italicizes",
+ "italicising": "italicizing",
+ "itemise": "itemize",
+ "itemised": "itemized",
+ "itemises": "itemizes",
+ "itemising": "itemizing",
+ "jeopardise": "jeopardize",
+ "jeopardised": "jeopardized",
+ "jeopardises": "jeopardizes",
+ "jeopardising": "jeopardizing",
+ "jewelled": "jeweled",
+ "jeweller": "jeweler",
+ "jewellers": "jewelers",
+ "jewellery": "jewelry",
+ "judgement": "judgment",
+ "kilogramme": "kilogram",
+ "kilogrammes": "kilograms",
+ "kilometre": "kilometer",
+ "kilometres": "kilometers",
+ "labelled": "labeled",
+ "labelling": "labeling",
+ "labour": "labor",
+ "laboured": "labored",
+ "labourer": "laborer",
+ "labourers": "laborers",
+ "labouring": "laboring",
+ "labours": "labors",
+ "lacklustre": "lackluster",
+ "legalisation": "legalization",
+ "legalise": "legalize",
+ "legalised": "legalized",
+ "legalises": "legalizes",
+ "legalising": "legalizing",
+ "legitimise": "legitimize",
+ "legitimised": "legitimized",
+ "legitimises": "legitimizes",
+ "legitimising": "legitimizing",
+ "leukaemia": "leukemia",
+ "levelled": "leveled",
+ "leveller": "leveler",
+ "levellers": "levelers",
+ "levelling": "leveling",
+ "libelled": "libeled",
+ "libelling": "libeling",
+ "libellous": "libelous",
+ "liberalisation": "liberalization",
+ "liberalise": "liberalize",
+ "liberalised": "liberalized",
+ "liberalises": "liberalizes",
+ "liberalising": "liberalizing",
+ "licence": "license",
+ "licenced": "licensed",
+ "licences": "licenses",
+ "licencing": "licensing",
+ "likeable": "likable",
+ "lionisation": "lionization",
+ "lionise": "lionize",
+ "lionised": "lionized",
+ "lionises": "lionizes",
+ "lionising": "lionizing",
+ "liquidise": "liquidize",
+ "liquidised": "liquidized",
+ "liquidiser": "liquidizer",
+ "liquidisers": "liquidizers",
+ "liquidises": "liquidizes",
+ "liquidising": "liquidizing",
+ "litre": "liter",
+ "litres": "liters",
+ "localise": "localize",
+ "localised": "localized",
+ "localises": "localizes",
+ "localising": "localizing",
+ "louvre": "louver",
+ "louvred": "louvered",
+ "louvres": "louvers",
+ "lustre": "luster",
+ "magnetise": "magnetize",
+ "magnetised": "magnetized",
+ "magnetises": "magnetizes",
+ "magnetising": "magnetizing",
+ "manoeuvrability": "maneuverability",
+ "manoeuvrable": "maneuverable",
+ "manoeuvre": "maneuver",
+ "manoeuvred": "maneuvered",
+ "manoeuvres": "maneuvers",
+ "manoeuvring": "maneuvering",
+ "manoeuvrings": "maneuverings",
+ "marginalisation": "marginalization",
+ "marginalise": "marginalize",
+ "marginalised": "marginalized",
+ "marginalises": "marginalizes",
+ "marginalising": "marginalizing",
+ "marshalled": "marshaled",
+ "marshalling": "marshaling",
+ "marvelled": "marveled",
+ "marvelling": "marveling",
+ "marvellous": "marvelous",
+ "marvellously": "marvelously",
+ "materialisation": "materialization",
+ "materialise": "materialize",
+ "materialised": "materialized",
+ "materialises": "materializes",
+ "materialising": "materializing",
+ "maximisation": "maximization",
+ "maximise": "maximize",
+ "maximised": "maximized",
+ "maximises": "maximizes",
+ "maximising": "maximizing",
+ "meagre": "meager",
+ "mechanisation": "mechanization",
+ "mechanise": "mechanize",
+ "mechanised": "mechanized",
+ "mechanises": "mechanizes",
+ "mechanising": "mechanizing",
+ "mediaeval": "medieval",
+ "memorialise": "memorialize",
+ "memorialised": "memorialized",
+ "memorialises": "memorializes",
+ "memorialising": "memorializing",
+ "memorise": "memorize",
+ "memorised": "memorized",
+ "memorises": "memorizes",
+ "memorising": "memorizing",
+ "mesmerise": "mesmerize",
+ "mesmerised": "mesmerized",
+ "mesmerises": "mesmerizes",
+ "mesmerising": "mesmerizing",
+ "metabolise": "metabolize",
+ "metabolised": "metabolized",
+ "metabolises": "metabolizes",
+ "metabolising": "metabolizing",
+ "metre": "meter",
+ "metres": "meters",
+ "micrometre": "micrometer",
+ "micrometres": "micrometers",
+ "militarise": "militarize",
+ "militarised": "militarized",
+ "militarises": "militarizes",
+ "militarising": "militarizing",
+ "milligramme": "milligram",
+ "milligrammes": "milligrams",
+ "millilitre": "milliliter",
+ "millilitres": "milliliters",
+ "millimetre": "millimeter",
+ "millimetres": "millimeters",
+ "miniaturisation": "miniaturization",
+ "miniaturise": "miniaturize",
+ "miniaturised": "miniaturized",
+ "miniaturises": "miniaturizes",
+ "miniaturising": "miniaturizing",
+ "minibusses": "minibuses",
+ "minimise": "minimize",
+ "minimised": "minimized",
+ "minimises": "minimizes",
+ "minimising": "minimizing",
+ "misbehaviour": "misbehavior",
+ "misdemeanour": "misdemeanor",
+ "misdemeanours": "misdemeanors",
+ "misspelt": "misspelled",
+ "mitre": "miter",
+ "mitres": "miters",
+ "mobilisation": "mobilization",
+ "mobilise": "mobilize",
+ "mobilised": "mobilized",
+ "mobilises": "mobilizes",
+ "mobilising": "mobilizing",
+ "modelled": "modeled",
+ "modeller": "modeler",
+ "modellers": "modelers",
+ "modelling": "modeling",
+ "modernise": "modernize",
+ "modernised": "modernized",
+ "modernises": "modernizes",
+ "modernising": "modernizing",
+ "moisturise": "moisturize",
+ "moisturised": "moisturized",
+ "moisturiser": "moisturizer",
+ "moisturisers": "moisturizers",
+ "moisturises": "moisturizes",
+ "moisturising": "moisturizing",
+ "monologue": "monolog",
+ "monologues": "monologs",
+ "monopolisation": "monopolization",
+ "monopolise": "monopolize",
+ "monopolised": "monopolized",
+ "monopolises": "monopolizes",
+ "monopolising": "monopolizing",
+ "moralise": "moralize",
+ "moralised": "moralized",
+ "moralises": "moralizes",
+ "moralising": "moralizing",
+ "motorised": "motorized",
+ "mould": "mold",
+ "moulded": "molded",
+ "moulder": "molder",
+ "mouldered": "moldered",
+ "mouldering": "moldering",
+ "moulders": "molders",
+ "mouldier": "moldier",
+ "mouldiest": "moldiest",
+ "moulding": "molding",
+ "mouldings": "moldings",
+ "moulds": "molds",
+ "mouldy": "moldy",
+ "moult": "molt",
+ "moulted": "molted",
+ "moulting": "molting",
+ "moults": "molts",
+ "moustache": "mustache",
+ "moustached": "mustached",
+ "moustaches": "mustaches",
+ "moustachioed": "mustachioed",
+ "multicoloured": "multicolored",
+ "nationalisation": "nationalization",
+ "nationalisations": "nationalizations",
+ "nationalise": "nationalize",
+ "nationalised": "nationalized",
+ "nationalises": "nationalizes",
+ "nationalising": "nationalizing",
+ "naturalisation": "naturalization",
+ "naturalise": "naturalize",
+ "naturalised": "naturalized",
+ "naturalises": "naturalizes",
+ "naturalising": "naturalizing",
+ "neighbour": "neighbor",
+ "neighbourhood": "neighborhood",
+ "neighbourhoods": "neighborhoods",
+ "neighbouring": "neighboring",
+ "neighbourliness": "neighborliness",
+ "neighbourly": "neighborly",
+ "neighbours": "neighbors",
+ "neutralisation": "neutralization",
+ "neutralise": "neutralize",
+ "neutralised": "neutralized",
+ "neutralises": "neutralizes",
+ "neutralising": "neutralizing",
+ "normalisation": "normalization",
+ "normalise": "normalize",
+ "normalised": "normalized",
+ "normalises": "normalizes",
+ "normalising": "normalizing",
+ "odour": "odor",
+ "odourless": "odorless",
+ "odours": "odors",
+ "oesophagus": "esophagus",
+ "oesophaguses": "esophaguses",
+ "oestrogen": "estrogen",
+ "offence": "offense",
+ "offences": "offenses",
+ "omelette": "omelet",
+ "omelettes": "omelets",
+ "optimise": "optimize",
+ "optimised": "optimized",
+ "optimises": "optimizes",
+ "optimising": "optimizing",
+ "organisation": "organization",
+ "organisational": "organizational",
+ "organisations": "organizations",
+ "organise": "organize",
+ "organised": "organized",
+ "organiser": "organizer",
+ "organisers": "organizers",
+ "organises": "organizes",
+ "organising": "organizing",
+ "orthopaedic": "orthopedic",
+ "orthopaedics": "orthopedics",
+ "ostracise": "ostracize",
+ "ostracised": "ostracized",
+ "ostracises": "ostracizes",
+ "ostracising": "ostracizing",
+ "outmanoeuvre": "outmaneuver",
+ "outmanoeuvred": "outmaneuvered",
+ "outmanoeuvres": "outmaneuvers",
+ "outmanoeuvring": "outmaneuvering",
+ "overemphasise": "overemphasize",
+ "overemphasised": "overemphasized",
+ "overemphasises": "overemphasizes",
+ "overemphasising": "overemphasizing",
+ "oxidisation": "oxidization",
+ "oxidise": "oxidize",
+ "oxidised": "oxidized",
+ "oxidises": "oxidizes",
+ "oxidising": "oxidizing",
+ "paederast": "pederast",
+ "paederasts": "pederasts",
+ "paediatric": "pediatric",
+ "paediatrician": "pediatrician",
+ "paediatricians": "pediatricians",
+ "paediatrics": "pediatrics",
+ "paedophile": "pedophile",
+ "paedophiles": "pedophiles",
+ "paedophilia": "pedophilia",
+ "palaeolithic": "paleolithic",
+ "palaeontologist": "paleontologist",
+ "palaeontologists": "paleontologists",
+ "palaeontology": "paleontology",
+ "panelled": "paneled",
+ "panelling": "paneling",
+ "panellist": "panelist",
+ "panellists": "panelists",
+ "paralyse": "paralyze",
+ "paralysed": "paralyzed",
+ "paralyses": "paralyzes",
+ "paralysing": "paralyzing",
+ "parcelled": "parceled",
+ "parcelling": "parceling",
+ "parlour": "parlor",
+ "parlours": "parlors",
+ "particularise": "particularize",
+ "particularised": "particularized",
+ "particularises": "particularizes",
+ "particularising": "particularizing",
+ "passivisation": "passivization",
+ "passivise": "passivize",
+ "passivised": "passivized",
+ "passivises": "passivizes",
+ "passivising": "passivizing",
+ "pasteurisation": "pasteurization",
+ "pasteurise": "pasteurize",
+ "pasteurised": "pasteurized",
+ "pasteurises": "pasteurizes",
+ "pasteurising": "pasteurizing",
+ "patronise": "patronize",
+ "patronised": "patronized",
+ "patronises": "patronizes",
+ "patronising": "patronizing",
+ "patronisingly": "patronizingly",
+ "pedalled": "pedaled",
+ "pedalling": "pedaling",
+ "pedestrianisation": "pedestrianization",
+ "pedestrianise": "pedestrianize",
+ "pedestrianised": "pedestrianized",
+ "pedestrianises": "pedestrianizes",
+ "pedestrianising": "pedestrianizing",
+ "penalise": "penalize",
+ "penalised": "penalized",
+ "penalises": "penalizes",
+ "penalising": "penalizing",
+ "pencilled": "penciled",
+ "pencilling": "penciling",
+ "personalise": "personalize",
+ "personalised": "personalized",
+ "personalises": "personalizes",
+ "personalising": "personalizing",
+ "pharmacopoeia": "pharmacopeia",
+ "pharmacopoeias": "pharmacopeias",
+ "philosophise": "philosophize",
+ "philosophised": "philosophized",
+ "philosophises": "philosophizes",
+ "philosophising": "philosophizing",
+ "philtre": "filter",
+ "philtres": "filters",
+ "phoney": "phony",
+ "plagiarise": "plagiarize",
+ "plagiarised": "plagiarized",
+ "plagiarises": "plagiarizes",
+ "plagiarising": "plagiarizing",
+ "plough": "plow",
+ "ploughed": "plowed",
+ "ploughing": "plowing",
+ "ploughman": "plowman",
+ "ploughmen": "plowmen",
+ "ploughs": "plows",
+ "ploughshare": "plowshare",
+ "ploughshares": "plowshares",
+ "polarisation": "polarization",
+ "polarise": "polarize",
+ "polarised": "polarized",
+ "polarises": "polarizes",
+ "polarising": "polarizing",
+ "politicisation": "politicization",
+ "politicise": "politicize",
+ "politicised": "politicized",
+ "politicises": "politicizes",
+ "politicising": "politicizing",
+ "popularisation": "popularization",
+ "popularise": "popularize",
+ "popularised": "popularized",
+ "popularises": "popularizes",
+ "popularising": "popularizing",
+ "pouffe": "pouf",
+ "pouffes": "poufs",
+ "practise": "practice",
+ "practised": "practiced",
+ "practises": "practices",
+ "practising": "practicing",
+ "praesidium": "presidium",
+ "praesidiums": "presidiums",
+ "pressurisation": "pressurization",
+ "pressurise": "pressurize",
+ "pressurised": "pressurized",
+ "pressurises": "pressurizes",
+ "pressurising": "pressurizing",
+ "pretence": "pretense",
+ "pretences": "pretenses",
+ "primaeval": "primeval",
+ "prioritisation": "prioritization",
+ "prioritise": "prioritize",
+ "prioritised": "prioritized",
+ "prioritises": "prioritizes",
+ "prioritising": "prioritizing",
+ "privatisation": "privatization",
+ "privatisations": "privatizations",
+ "privatise": "privatize",
+ "privatised": "privatized",
+ "privatises": "privatizes",
+ "privatising": "privatizing",
+ "professionalisation": "professionalization",
+ "professionalise": "professionalize",
+ "professionalised": "professionalized",
+ "professionalises": "professionalizes",
+ "professionalising": "professionalizing",
+ "programme": "program",
+ "programmes": "programs",
+ "prologue": "prolog",
+ "prologues": "prologs",
+ "propagandise": "propagandize",
+ "propagandised": "propagandized",
+ "propagandises": "propagandizes",
+ "propagandising": "propagandizing",
+ "proselytise": "proselytize",
+ "proselytised": "proselytized",
+ "proselytiser": "proselytizer",
+ "proselytisers": "proselytizers",
+ "proselytises": "proselytizes",
+ "proselytising": "proselytizing",
+ "psychoanalyse": "psychoanalyze",
+ "psychoanalysed": "psychoanalyzed",
+ "psychoanalyses": "psychoanalyzes",
+ "psychoanalysing": "psychoanalyzing",
+ "publicise": "publicize",
+ "publicised": "publicized",
+ "publicises": "publicizes",
+ "publicising": "publicizing",
+ "pulverisation": "pulverization",
+ "pulverise": "pulverize",
+ "pulverised": "pulverized",
+ "pulverises": "pulverizes",
+ "pulverising": "pulverizing",
+ "pummelled": "pummel",
+ "pummelling": "pummeled",
+ "pyjama": "pajama",
+ "pyjamas": "pajamas",
+ "pzazz": "pizzazz",
+ "quarrelled": "quarreled",
+ "quarrelling": "quarreling",
+ "radicalise": "radicalize",
+ "radicalised": "radicalized",
+ "radicalises": "radicalizes",
+ "radicalising": "radicalizing",
+ "rancour": "rancor",
+ "randomise": "randomize",
+ "randomised": "randomized",
+ "randomises": "randomizes",
+ "randomising": "randomizing",
+ "rationalisation": "rationalization",
+ "rationalisations": "rationalizations",
+ "rationalise": "rationalize",
+ "rationalised": "rationalized",
+ "rationalises": "rationalizes",
+ "rationalising": "rationalizing",
+ "ravelled": "raveled",
+ "ravelling": "raveling",
+ "realisable": "realizable",
+ "realisation": "realization",
+ "realisations": "realizations",
+ "realise": "realize",
+ "realised": "realized",
+ "realises": "realizes",
+ "realising": "realizing",
+ "recognisable": "recognizable",
+ "recognisably": "recognizably",
+ "recognisance": "recognizance",
+ "recognise": "recognize",
+ "recognised": "recognized",
+ "recognises": "recognizes",
+ "recognising": "recognizing",
+ "reconnoitre": "reconnoiter",
+ "reconnoitred": "reconnoitered",
+ "reconnoitres": "reconnoiters",
+ "reconnoitring": "reconnoitering",
+ "refuelled": "refueled",
+ "refuelling": "refueling",
+ "regularisation": "regularization",
+ "regularise": "regularize",
+ "regularised": "regularized",
+ "regularises": "regularizes",
+ "regularising": "regularizing",
+ "remodelled": "remodeled",
+ "remodelling": "remodeling",
+ "remould": "remold",
+ "remoulded": "remolded",
+ "remoulding": "remolding",
+ "remoulds": "remolds",
+ "reorganisation": "reorganization",
+ "reorganisations": "reorganizations",
+ "reorganise": "reorganize",
+ "reorganised": "reorganized",
+ "reorganises": "reorganizes",
+ "reorganising": "reorganizing",
+ "revelled": "reveled",
+ "reveller": "reveler",
+ "revellers": "revelers",
+ "revelling": "reveling",
+ "revitalise": "revitalize",
+ "revitalised": "revitalized",
+ "revitalises": "revitalizes",
+ "revitalising": "revitalizing",
+ "revolutionise": "revolutionize",
+ "revolutionised": "revolutionized",
+ "revolutionises": "revolutionizes",
+ "revolutionising": "revolutionizing",
+ "rhapsodise": "rhapsodize",
+ "rhapsodised": "rhapsodized",
+ "rhapsodises": "rhapsodizes",
+ "rhapsodising": "rhapsodizing",
+ "rigour": "rigor",
+ "rigours": "rigors",
+ "ritualised": "ritualized",
+ "rivalled": "rivaled",
+ "rivalling": "rivaling",
+ "romanticise": "romanticize",
+ "romanticised": "romanticized",
+ "romanticises": "romanticizes",
+ "romanticising": "romanticizing",
+ "rumour": "rumor",
+ "rumoured": "rumored",
+ "rumours": "rumors",
+ "sabre": "saber",
+ "sabres": "sabers",
+ "saltpetre": "saltpeter",
+ "sanitise": "sanitize",
+ "sanitised": "sanitized",
+ "sanitises": "sanitizes",
+ "sanitising": "sanitizing",
+ "satirise": "satirize",
+ "satirised": "satirized",
+ "satirises": "satirizes",
+ "satirising": "satirizing",
+ "saviour": "savior",
+ "saviours": "saviors",
+ "savour": "savor",
+ "savoured": "savored",
+ "savouries": "savories",
+ "savouring": "savoring",
+ "savours": "savors",
+ "savoury": "savory",
+ "scandalise": "scandalize",
+ "scandalised": "scandalized",
+ "scandalises": "scandalizes",
+ "scandalising": "scandalizing",
+ "sceptic": "skeptic",
+ "sceptical": "skeptical",
+ "sceptically": "skeptically",
+ "scepticism": "skepticism",
+ "sceptics": "skeptics",
+ "sceptre": "scepter",
+ "sceptres": "scepters",
+ "scrutinise": "scrutinize",
+ "scrutinised": "scrutinized",
+ "scrutinises": "scrutinizes",
+ "scrutinising": "scrutinizing",
+ "secularisation": "secularization",
+ "secularise": "secularize",
+ "secularised": "secularized",
+ "secularises": "secularizes",
+ "secularising": "secularizing",
+ "sensationalise": "sensationalize",
+ "sensationalised": "sensationalized",
+ "sensationalises": "sensationalizes",
+ "sensationalising": "sensationalizing",
+ "sensitise": "sensitize",
+ "sensitised": "sensitized",
+ "sensitises": "sensitizes",
+ "sensitising": "sensitizing",
+ "sentimentalise": "sentimentalize",
+ "sentimentalised": "sentimentalized",
+ "sentimentalises": "sentimentalizes",
+ "sentimentalising": "sentimentalizing",
+ "sepulchre": "sepulcher",
+ "sepulchres": "sepulchers",
+ "serialisation": "serialization",
+ "serialisations": "serializations",
+ "serialise": "serialize",
+ "serialised": "serialized",
+ "serialises": "serializes",
+ "serialising": "serializing",
+ "sermonise": "sermonize",
+ "sermonised": "sermonized",
+ "sermonises": "sermonizes",
+ "sermonising": "sermonizing",
+ "sheikh": "sheik",
+ "shovelled": "shoveled",
+ "shovelling": "shoveling",
+ "shrivelled": "shriveled",
+ "shrivelling": "shriveling",
+ "signalise": "signalize",
+ "signalised": "signalized",
+ "signalises": "signalizes",
+ "signalising": "signalizing",
+ "signalled": "signaled",
+ "signalling": "signaling",
+ "smoulder": "smolder",
+ "smouldered": "smoldered",
+ "smouldering": "smoldering",
+ "smoulders": "smolders",
+ "snivelled": "sniveled",
+ "snivelling": "sniveling",
+ "snorkelled": "snorkeled",
+ "snorkelling": "snorkeling",
+ "snowplough": "snowplow",
+ "snowploughs": "snowplow",
+ "socialisation": "socialization",
+ "socialise": "socialize",
+ "socialised": "socialized",
+ "socialises": "socializes",
+ "socialising": "socializing",
+ "sodomise": "sodomize",
+ "sodomised": "sodomized",
+ "sodomises": "sodomizes",
+ "sodomising": "sodomizing",
+ "solemnise": "solemnize",
+ "solemnised": "solemnized",
+ "solemnises": "solemnizes",
+ "solemnising": "solemnizing",
+ "sombre": "somber",
+ "specialisation": "specialization",
+ "specialisations": "specializations",
+ "specialise": "specialize",
+ "specialised": "specialized",
+ "specialises": "specializes",
+ "specialising": "specializing",
+ "spectre": "specter",
+ "spectres": "specters",
+ "spiralled": "spiraled",
+ "spiralling": "spiraling",
+ "splendour": "splendor",
+ "splendours": "splendors",
+ "squirrelled": "squirreled",
+ "squirrelling": "squirreling",
+ "stabilisation": "stabilization",
+ "stabilise": "stabilize",
+ "stabilised": "stabilized",
+ "stabiliser": "stabilizer",
+ "stabilisers": "stabilizers",
+ "stabilises": "stabilizes",
+ "stabilising": "stabilizing",
+ "standardisation": "standardization",
+ "standardise": "standardize",
+ "standardised": "standardized",
+ "standardises": "standardizes",
+ "standardising": "standardizing",
+ "stencilled": "stenciled",
+ "stencilling": "stenciling",
+ "sterilisation": "sterilization",
+ "sterilisations": "sterilizations",
+ "sterilise": "sterilize",
+ "sterilised": "sterilized",
+ "steriliser": "sterilizer",
+ "sterilisers": "sterilizers",
+ "sterilises": "sterilizes",
+ "sterilising": "sterilizing",
+ "stigmatisation": "stigmatization",
+ "stigmatise": "stigmatize",
+ "stigmatised": "stigmatized",
+ "stigmatises": "stigmatizes",
+ "stigmatising": "stigmatizing",
+ "storey": "story",
+ "storeys": "stories",
+ "subsidisation": "subsidization",
+ "subsidise": "subsidize",
+ "subsidised": "subsidized",
+ "subsidiser": "subsidizer",
+ "subsidisers": "subsidizers",
+ "subsidises": "subsidizes",
+ "subsidising": "subsidizing",
+ "succour": "succor",
+ "succoured": "succored",
+ "succouring": "succoring",
+ "succours": "succors",
+ "sulphate": "sulfate",
+ "sulphates": "sulfates",
+ "sulphide": "sulfide",
+ "sulphides": "sulfides",
+ "sulphur": "sulfur",
+ "sulphurous": "sulfurous",
+ "summarise": "summarize",
+ "summarised": "summarized",
+ "summarises": "summarizes",
+ "summarising": "summarizing",
+ "swivelled": "swiveled",
+ "swivelling": "swiveling",
+ "symbolise": "symbolize",
+ "symbolised": "symbolized",
+ "symbolises": "symbolizes",
+ "symbolising": "symbolizing",
+ "sympathise": "sympathize",
+ "sympathised": "sympathized",
+ "sympathiser": "sympathizer",
+ "sympathisers": "sympathizers",
+ "sympathises": "sympathizes",
+ "sympathising": "sympathizing",
+ "synchronisation": "synchronization",
+ "synchronise": "synchronize",
+ "synchronised": "synchronized",
+ "synchronises": "synchronizes",
+ "synchronising": "synchronizing",
+ "synthesise": "synthesize",
+ "synthesised": "synthesized",
+ "synthesiser": "synthesizer",
+ "synthesisers": "synthesizers",
+ "synthesises": "synthesizes",
+ "synthesising": "synthesizing",
+ "syphon": "siphon",
+ "syphoned": "siphoned",
+ "syphoning": "siphoning",
+ "syphons": "siphons",
+ "systematisation": "systematization",
+ "systematise": "systematize",
+ "systematised": "systematized",
+ "systematises": "systematizes",
+ "systematising": "systematizing",
+ "tantalise": "tantalize",
+ "tantalised": "tantalized",
+ "tantalises": "tantalizes",
+ "tantalising": "tantalizing",
+ "tantalisingly": "tantalizingly",
+ "tasselled": "tasseled",
+ "technicolour": "technicolor",
+ "temporise": "temporize",
+ "temporised": "temporized",
+ "temporises": "temporizes",
+ "temporising": "temporizing",
+ "tenderise": "tenderize",
+ "tenderised": "tenderized",
+ "tenderises": "tenderizes",
+ "tenderising": "tenderizing",
+ "terrorise": "terrorize",
+ "terrorised": "terrorized",
+ "terrorises": "terrorizes",
+ "terrorising": "terrorizing",
+ "theatre": "theater",
+ "theatregoer": "theatergoer",
+ "theatregoers": "theatergoers",
+ "theatres": "theaters",
+ "theorise": "theorize",
+ "theorised": "theorized",
+ "theorises": "theorizes",
+ "theorising": "theorizing",
+ "tonne": "ton",
+ "tonnes": "tons",
+ "towelled": "toweled",
+ "towelling": "toweling",
+ "toxaemia": "toxemia",
+ "tranquillise": "tranquilize",
+ "tranquillised": "tranquilized",
+ "tranquilliser": "tranquilizer",
+ "tranquillisers": "tranquilizers",
+ "tranquillises": "tranquilizes",
+ "tranquillising": "tranquilizing",
+ "tranquillity": "tranquility",
+ "tranquillize": "tranquilize",
+ "tranquillized": "tranquilized",
+ "tranquillizer": "tranquilizer",
+ "tranquillizers": "tranquilizers",
+ "tranquillizes": "tranquilizes",
+ "tranquillizing": "tranquilizing",
+ "tranquilly": "tranquility",
+ "transistorised": "transistorized",
+ "traumatise": "traumatize",
+ "traumatised": "traumatized",
+ "traumatises": "traumatizes",
+ "traumatising": "traumatizing",
+ "travelled": "traveled",
+ "traveller": "traveler",
+ "travellers": "travelers",
+ "travelling": "traveling",
+ "travelog": "travelogue",
+ "travelogs": "travelogues",
+ "trialled": "trialed",
+ "trialling": "trialing",
+ "tricolour": "tricolor",
+ "tricolours": "tricolors",
+ "trivialise": "trivialize",
+ "trivialised": "trivialized",
+ "trivialises": "trivializes",
+ "trivialising": "trivializing",
+ "tumour": "tumor",
+ "tumours": "tumors",
+ "tunnelled": "tunneled",
+ "tunnelling": "tunneling",
+ "tyrannise": "tyrannize",
+ "tyrannised": "tyrannized",
+ "tyrannises": "tyrannizes",
+ "tyrannising": "tyrannizing",
+ "tyre": "tire",
+ "tyres": "tires",
+ "unauthorised": "unauthorized",
+ "uncivilised": "uncivilized",
+ "underutilised": "underutilized",
+ "unequalled": "unequaled",
+ "unfavourable": "unfavorable",
+ "unfavourably": "unfavorably",
+ "unionisation": "unionization",
+ "unionise": "unionize",
+ "unionised": "unionized",
+ "unionises": "unionizes",
+ "unionising": "unionizing",
+ "unorganised": "unorganized",
+ "unravelled": "unraveled",
+ "unravelling": "unraveling",
+ "unrecognisable": "unrecognizable",
+ "unrecognised": "unrecognized",
+ "unrivalled": "unrivaled",
+ "unsavoury": "unsavory",
+ "untrammelled": "untrammeled",
+ "urbanisation": "urbanization",
+ "urbanise": "urbanize",
+ "urbanised": "urbanized",
+ "urbanises": "urbanizes",
+ "urbanising": "urbanizing",
+ "utilisable": "utilizable",
+ "utilisation": "utilization",
+ "utilise": "utilize",
+ "utilised": "utilized",
+ "utilises": "utilizes",
+ "utilising": "utilizing",
+ "valour": "valor",
+ "vandalise": "vandalize",
+ "vandalised": "vandalized",
+ "vandalises": "vandalizes",
+ "vandalising": "vandalizing",
+ "vaporisation": "vaporization",
+ "vaporise": "vaporize",
+ "vaporised": "vaporized",
+ "vaporises": "vaporizes",
+ "vaporising": "vaporizing",
+ "vapour": "vapor",
+ "vapours": "vapors",
+ "verbalise": "verbalize",
+ "verbalised": "verbalized",
+ "verbalises": "verbalizes",
+ "verbalising": "verbalizing",
+ "victimisation": "victimization",
+ "victimise": "victimize",
+ "victimised": "victimized",
+ "victimises": "victimizes",
+ "victimising": "victimizing",
+ "videodisc": "videodisk",
+ "videodiscs": "videodisks",
+ "vigour": "vigor",
+ "visualisation": "visualization",
+ "visualisations": "visualizations",
+ "visualise": "visualize",
+ "visualised": "visualized",
+ "visualises": "visualizes",
+ "visualising": "visualizing",
+ "vocalisation": "vocalization",
+ "vocalisations": "vocalizations",
+ "vocalise": "vocalize",
+ "vocalised": "vocalized",
+ "vocalises": "vocalizes",
+ "vocalising": "vocalizing",
+ "vulcanised": "vulcanized",
+ "vulgarisation": "vulgarization",
+ "vulgarise": "vulgarize",
+ "vulgarised": "vulgarized",
+ "vulgarises": "vulgarizes",
+ "vulgarising": "vulgarizing",
+ "waggon": "wagon",
+ "waggons": "wagons",
+ "watercolour": "watercolor",
+ "watercolours": "watercolors",
+ "weaselled": "weaseled",
+ "weaselling": "weaseling",
+ "westernisation": "westernization",
+ "westernise": "westernize",
+ "westernised": "westernized",
+ "westernises": "westernizes",
+ "westernising": "westernizing",
+ "womanise": "womanize",
+ "womanised": "womanized",
+ "womaniser": "womanizer",
+ "womanisers": "womanizers",
+ "womanises": "womanizes",
+ "womanising": "womanizing",
+ "woollen": "woolen",
+ "woollens": "woolens",
+ "woollies": "woolies",
+ "woolly": "wooly",
+ "worshipped": "worshiped",
+ "worshipping": "worshiping",
+ "worshipper": "worshiper",
+ "yodelled": "yodeled",
+ "yodelling": "yodeling",
+ "yoghourt": "yogurt",
+ "yoghourts": "yogurts",
+ "yoghurt": "yogurt",
+ "yoghurts": "yogurts",
+ "mhm": "hmm",
+ "mmm": "hmm"
+}
\ No newline at end of file
diff --git a/lmms_eval/tasks/librispeech/whisper_normalizer/english.py b/lmms_eval/tasks/librispeech/whisper_normalizer/english.py
new file mode 100644
index 00000000..7102f2ec
--- /dev/null
+++ b/lmms_eval/tasks/librispeech/whisper_normalizer/english.py
@@ -0,0 +1,529 @@
+import json
+import os
+import re
+from fractions import Fraction
+from typing import Iterator, List, Match, Optional, Union
+
+from more_itertools import windowed # TODO: new third-party dependency (more_itertools)
+
+from .basic import remove_symbols_and_diacritics
+
+
+class EnglishNumberNormalizer:
+ """
+ Convert any spelled-out numbers into arabic numbers, while handling:
+
+ - remove any commas
+ - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
+ - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
+ - spell out `one` and `ones`
+ - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
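+
+ For example (illustrative), `list(EnglishNumberNormalizer().process_words("one oh one".split()))`
+ yields `["101"]`.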
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ self.zeros = {"o", "oh", "zero"}
+ self.ones = {
+ name: i
+ for i, name in enumerate(
+ [
+ "one",
+ "two",
+ "three",
+ "four",
+ "five",
+ "six",
+ "seven",
+ "eight",
+ "nine",
+ "ten",
+ "eleven",
+ "twelve",
+ "thirteen",
+ "fourteen",
+ "fifteen",
+ "sixteen",
+ "seventeen",
+ "eighteen",
+ "nineteen",
+ ],
+ start=1,
+ )
+ }
+ self.ones_plural = {"sixes" if name == "six" else name + "s": (value, "s") for name, value in self.ones.items()}
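+ # Ordinal spellings: irregular forms are listed explicitly; the rest append
+ # "th", or just "h" when the cardinal already ends in "t" (e.g. "eighth").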
+ self.ones_ordinal = {
+ "zeroth": (0, "th"),
+ "first": (1, "st"),
+ "second": (2, "nd"),
+ "third": (3, "rd"),
+ "fifth": (5, "th"),
+ "twelfth": (12, "th"),
+ **{name + ("h" if name.endswith("t") else "th"): (value, "th") for name, value in self.ones.items() if value > 3 and value != 5 and value != 12},
+ }
+ self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
+
+ self.tens = {
+ "twenty": 20,
+ "thirty": 30,
+ "forty": 40,
+ "fifty": 50,
+ "sixty": 60,
+ "seventy": 70,
+ "eighty": 80,
+ "ninety": 90,
+ }
+ self.tens_plural = {name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()}
+ self.tens_ordinal = {name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()}
+ self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
+
+ self.multipliers = {
+ "hundred": 100,
+ "thousand": 1_000,
+ "million": 1_000_000,
+ "billion": 1_000_000_000,
+ "trillion": 1_000_000_000_000,
+ "quadrillion": 1_000_000_000_000_000,
+ "quintillion": 1_000_000_000_000_000_000,
+ "sextillion": 1_000_000_000_000_000_000_000,
+ "septillion": 1_000_000_000_000_000_000_000_000,
+ "octillion": 1_000_000_000_000_000_000_000_000_000,
+ "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
+ "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
+ }
+ self.multipliers_plural = {name + "s": (value, "s") for name, value in self.multipliers.items()}
+ self.multipliers_ordinal = {name + "th": (value, "th") for name, value in self.multipliers.items()}
+ self.multipliers_suffixed = {
+ **self.multipliers_plural,
+ **self.multipliers_ordinal,
+ }
+ self.decimals = {*self.ones, *self.tens, *self.zeros}
+
+ self.preceding_prefixers = {
+ "minus": "-",
+ "negative": "-",
+ "plus": "+",
+ "positive": "+",
+ }
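+ # Currency words become symbols prefixed to the number they follow,
+ # e.g. "five dollars" -> "$5" (illustrative).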
+ self.following_prefixers = {
+ "pound": "ยฃ",
+ "pounds": "ยฃ",
+ "euro": "โฌ",
+ "euros": "โฌ",
+ "dollar": "$",
+ "dollars": "$",
+ "cent": "ยข",
+ "cents": "ยข",
+ }
+ self.prefixes = set(list(self.preceding_prefixers.values()) + list(self.following_prefixers.values()))
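+ # "percent" maps straight to "%"; the nested dict covers the two-word form
+ # "per cent", resolved against the following token in process_words.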
+ self.suffixers = {
+ "per": {"cent": "%"},
+ "percent": "%",
+ }
+ self.specials = {"and", "double", "triple", "point"}
+
+ self.words = set(
+ [
+ key
+ for mapping in [
+ self.zeros,
+ self.ones,
+ self.ones_suffixed,
+ self.tens,
+ self.tens_suffixed,
+ self.multipliers,
+ self.multipliers_suffixed,
+ self.preceding_prefixers,
+ self.following_prefixers,
+ self.suffixers,
+ self.specials,
+ ]
+ for key in mapping
+ ]
+ )
+ self.literal_words = {"one", "ones"}
+
+ def process_words(self, words: List[str]) -> Iterator[str]:
+ prefix: Optional[str] = None
+ value: Optional[Union[str, int]] = None
+ skip = False
+
+ def to_fraction(s: str):
+ try:
+ return Fraction(s)
+ except ValueError:
+ return None
+
+ def output(result: Union[str, int]):
+ nonlocal prefix, value
+ result = str(result)
+ if prefix is not None:
+ result = prefix + result
+ value = None
+ prefix = None
+ return result
+
+ if len(words) == 0:
+ return
+
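+ # Slide a (prev, current, next) window over the tokens, padding both ends
+ # with None so each word is visited exactly once as `current`.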
+ for prev, current, next in windowed([None] + words + [None], 3):
+ if skip:
+ skip = False
+ continue
+
+ next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
+ has_prefix = current[0] in self.prefixes
+ current_without_prefix = current[1:] if has_prefix else current
+ if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
+ # arabic numbers (potentially with signs and fractions)
+ f = to_fraction(current_without_prefix)
+ assert f is not None
+ if value is not None:
+ if isinstance(value, str) and value.endswith("."):
+ # concatenate decimals / ip address components
+ value = str(value) + str(current)
+ continue
+ else:
+ yield output(value)
+
+ prefix = current[0] if has_prefix else prefix
+ if f.denominator == 1:
+ value = f.numerator # store integers as int
+ else:
+ value = current_without_prefix
+ elif current not in self.words:
+ # non-numeric words
+ if value is not None:
+ yield output(value)
+ yield output(current)
+ elif current in self.zeros:
+ value = str(value or "") + "0"
+ elif current in self.ones:
+ ones = self.ones[current]
+
+ if value is None:
+ value = ones
+ elif isinstance(value, str) or prev in self.ones:
+ if prev in self.tens and ones < 10: # replace the last zero with the digit
+ assert value[-1] == "0"
+ value = value[:-1] + str(ones)
+ else:
+ value = str(value) + str(ones)
+ elif ones < 10:
+ if value % 10 == 0:
+ value += ones
+ else:
+ value = str(value) + str(ones)
+ else: # eleven to nineteen
+ if value % 100 == 0:
+ value += ones
+ else:
+ value = str(value) + str(ones)
+ elif current in self.ones_suffixed:
+ # ordinal or cardinal; yield the number right away
+ ones, suffix = self.ones_suffixed[current]
+ if value is None:
+ yield output(str(ones) + suffix)
+ elif isinstance(value, str) or prev in self.ones:
+ if prev in self.tens and ones < 10:
+ assert value[-1] == "0"
+ yield output(value[:-1] + str(ones) + suffix)
+ else:
+ yield output(str(value) + str(ones) + suffix)
+ elif ones < 10:
+ if value % 10 == 0:
+ yield output(str(value + ones) + suffix)
+ else:
+ yield output(str(value) + str(ones) + suffix)
+ else: # eleven to nineteen
+ if value % 100 == 0:
+ yield output(str(value + ones) + suffix)
+ else:
+ yield output(str(value) + str(ones) + suffix)
+ value = None
+ elif current in self.tens:
+ tens = self.tens[current]
+ if value is None:
+ value = tens
+ elif isinstance(value, str):
+ value = str(value) + str(tens)
+ else:
+ if value % 100 == 0:
+ value += tens
+ else:
+ value = str(value) + str(tens)
+ elif current in self.tens_suffixed:
+ # ordinal or cardinal; yield the number right away
+ tens, suffix = self.tens_suffixed[current]
+ if value is None:
+ yield output(str(tens) + suffix)
+ elif isinstance(value, str):
+ yield output(str(value) + str(tens) + suffix)
+ else:
+ if value % 100 == 0:
+ yield output(str(value + tens) + suffix)
+ else:
+ yield output(str(value) + str(tens) + suffix)
+ elif current in self.multipliers:
+ multiplier = self.multipliers[current]
+ if value is None:
+ value = multiplier
+ elif isinstance(value, str) or value == 0:
+ f = to_fraction(value)
+ p = f * multiplier if f is not None else None
+ if f is not None and p.denominator == 1:
+ value = p.numerator
+ else:
+ yield output(value)
+ value = multiplier
+ else:
+ before = value // 1000 * 1000
+ residual = value % 1000
+ value = before + residual * multiplier
+ elif current in self.multipliers_suffixed:
+ multiplier, suffix = self.multipliers_suffixed[current]
+ if value is None:
+ yield output(str(multiplier) + suffix)
+ elif isinstance(value, str):
+ f = to_fraction(value)
+ p = f * multiplier if f is not None else None
+ if f is not None and p.denominator == 1:
+ yield output(str(p.numerator) + suffix)
+ else:
+ yield output(value)
+ yield output(str(multiplier) + suffix)
+ else: # int
+ before = value // 1000 * 1000
+ residual = value % 1000
+ value = before + residual * multiplier
+ yield output(str(value) + suffix)
+ value = None
+ elif current in self.preceding_prefixers:
+ # apply prefix (positive, minus, etc.) if it precedes a number
+ if value is not None:
+ yield output(value)
+
+ if next in self.words or next_is_numeric:
+ prefix = self.preceding_prefixers[current]
+ else:
+ yield output(current)
+ elif current in self.following_prefixers:
+ # apply prefix (dollars, cents, etc.) only after a number
+ if value is not None:
+ prefix = self.following_prefixers[current]
+ yield output(value)
+ else:
+ yield output(current)
+ elif current in self.suffixers:
+ # apply suffix symbols (percent -> '%')
+ if value is not None:
+ suffix = self.suffixers[current]
+ if isinstance(suffix, dict):
+ if next in suffix:
+ yield output(str(value) + suffix[next])
+ skip = True
+ else:
+ yield output(value)
+ yield output(current)
+ else:
+ yield output(str(value) + suffix)
+ else:
+ yield output(current)
+ elif current in self.specials:
+ if next not in self.words and not next_is_numeric:
+ # apply special handling only if the next word can be numeric
+ if value is not None:
+ yield output(value)
+ yield output(current)
+ elif current == "and":
+ # ignore "and" after hundreds, thousands, etc.
+ if prev not in self.multipliers:
+ if value is not None:
+ yield output(value)
+ yield output(current)
+ elif current == "double" or current == "triple":
+ if next in self.ones or next in self.zeros:
+ repeats = 2 if current == "double" else 3
+ ones = self.ones.get(next, 0)
+ value = str(value or "") + str(ones) * repeats
+ skip = True
+ else:
+ if value is not None:
+ yield output(value)
+ yield output(current)
+ elif current == "point":
+ if next in self.decimals or next_is_numeric:
+ value = str(value or "") + "."
+ else:
+ # should all have been covered at this point
+ raise ValueError(f"Unexpected token: {current}")
+ else:
+ # all should have been covered at this point
+ raise ValueError(f"Unexpected token: {current}")
+
+ if value is not None:
+ yield output(value)
+
+ def preprocess(self, s: str):
+ # replace " and a half" with " point five"
+ results = []
+
+ segments = re.split(r"\band\s+a\s+half\b", s)
+ for i, segment in enumerate(segments):
+ if len(segment.strip()) == 0:
+ continue
+ if i == len(segments) - 1:
+ results.append(segment)
+ else:
+ results.append(segment)
+ last_word = segment.rsplit(maxsplit=2)[-1]
+ if last_word in self.decimals or last_word in self.multipliers:
+ results.append("point five")
+ else:
+ results.append("and a half")
+
+ s = " ".join(results)
+
+ # put a space at number/letter boundary
+ s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
+ s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
+
+ # but remove spaces which could be a suffix
+ s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
+
+ return s
+
+ def postprocess(self, s: str):
+ def combine_cents(m: Match):
+ try:
+ currency = m.group(1)
+ integer = m.group(2)
+ cents = int(m.group(3))
+ return f"{currency}{integer}.{cents:02d}"
+ except ValueError:
+ return m.string
+
+ def extract_cents(m: Match):
+ try:
+ return f"ยข{int(m.group(1))}"
+ except ValueError:
+ return m.string
+
+ # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
+ s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
+ s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
+
+ # write "one(s)" instead of "1(s)", just for the readability
+ s = re.sub(r"\b1(s?)\b", r"one\1", s)
+
+ return s
+
+ def __call__(self, s: str):
+ s = self.preprocess(s)
+ s = " ".join(word for word in self.process_words(s.split()) if word is not None)
+ s = self.postprocess(s)
+
+ return s
+
+
+class EnglishSpellingNormalizer:
+ """
+ Applies British-American spelling mappings as listed in [1].
+
+ [1] https://www.tysto.com/uk-us-spelling-list.html
+ """
+
+ def __init__(self):
+ mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
+ self.mapping = json.load(open(mapping_path))
+
+ def __call__(self, s: str):
+ return " ".join(self.mapping.get(word, word) for word in s.split())
+
+
+class EnglishTextNormalizer:
+ def __init__(self):
+ self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
+ self.replacers = {
+ # common contractions
+ r"\bwon't\b": "will not",
+ r"\bcan't\b": "can not",
+ r"\blet's\b": "let us",
+ r"\bain't\b": "aint",
+ r"\by'all\b": "you all",
+ r"\bwanna\b": "want to",
+ r"\bgotta\b": "got to",
+ r"\bgonna\b": "going to",
+ r"\bi'ma\b": "i am going to",
+ r"\bimma\b": "i am going to",
+ r"\bwoulda\b": "would have",
+ r"\bcoulda\b": "could have",
+ r"\bshoulda\b": "should have",
+ r"\bma'am\b": "madam",
+ # contractions in titles/prefixes
+ r"\bmr\b": "mister ",
+ r"\bmrs\b": "missus ",
+ r"\bst\b": "saint ",
+ r"\bdr\b": "doctor ",
+ r"\bprof\b": "professor ",
+ r"\bcapt\b": "captain ",
+ r"\bgov\b": "governor ",
+ r"\bald\b": "alderman ",
+ r"\bgen\b": "general ",
+ r"\bsen\b": "senator ",
+ r"\brep\b": "representative ",
+ r"\bpres\b": "president ",
+ r"\brev\b": "reverend ",
+ r"\bhon\b": "honorable ",
+ r"\basst\b": "assistant ",
+ r"\bassoc\b": "associate ",
+ r"\blt\b": "lieutenant ",
+ r"\bcol\b": "colonel ",
+ r"\bjr\b": "junior ",
+ r"\bsr\b": "senior ",
+ r"\besq\b": "esquire ",
+ # perfect tenses, ideally it should be any past participles, but it's harder..
+ r"'d been\b": " had been",
+ r"'s been\b": " has been",
+ r"'d gone\b": " had gone",
+ r"'s gone\b": " has gone",
+ r"'d done\b": " had done", # "'s done" is ambiguous
+ r"'s got\b": " has got",
+ # general contractions
+ r"n't\b": " not",
+ r"'re\b": " are",
+ r"'s\b": " is",
+ r"'d\b": " would",
+ r"'ll\b": " will",
+ r"'t\b": " not",
+ r"'ve\b": " have",
+ r"'m\b": " am",
+ }
+ self.standardize_numbers = EnglishNumberNormalizer()
+ self.standardize_spellings = EnglishSpellingNormalizer()
+
+ def __call__(self, s: str):
+ s = s.lower()
+
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
+ s = re.sub(self.ignore_patterns, "", s)
+ s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
+
+ for pattern, replacement in self.replacers.items():
+ s = re.sub(pattern, replacement, s)
+
+ s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
+ s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
+ s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
+
+ s = self.standardize_numbers(s)
+ s = self.standardize_spellings(s)
+
+ # now remove prefix/suffix symbols that are not preceded/followed by numbers
+ s = re.sub(r"[.$ยขโฌยฃ]([^0-9])", r" \1", s)
+ s = re.sub(r"([^0-9])%", r"\1 ", s)
+
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
+
+ return s
diff --git a/lmms_eval/tasks/mix_evals/audio2text/_default_template_yaml b/lmms_eval/tasks/mix_evals/audio2text/_default_template_yaml
new file mode 100644
index 00000000..459386ab
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/audio2text/_default_template_yaml
@@ -0,0 +1,10 @@
+dataset_kwargs:
+ token: true
+dataset_path: lmms-lab/MixEval-X-audio2text
+lmms_eval_specific_kwargs:
+ default:
+ post_prompt: ""
+ pre_prompt: ""
+metadata:
+ gpt_eval_model_name: gpt-4o-mini
+ version: 0
\ No newline at end of file
diff --git a/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2_text_freeform.yaml b/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2_text_freeform.yaml
new file mode 100644
index 00000000..8d69fd2b
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2_text_freeform.yaml
@@ -0,0 +1,24 @@
+task: "mix_evals_audio2text_freeform"
+test_split: free_form
+output_type: generate_until
+doc_to_visual: !function utils.mix_evals_audio2text_doc_to_audio
+doc_to_text: !function utils.mix_evals_audio2text_doc_to_text
+doc_to_target: !function utils.mix_evals_audio2text_doc_to_target
+process_results: !function utils.mix_evals_audio2text_process_results_freeform
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.mix_evals_audio2text_gpt_eval
+ higher_is_better: true
+
+generation_kwargs:
+ max_new_tokens: 64
+
+include: _default_template_yaml
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "Answer the question using a single word or phrase."
+ gpt4v:
+ pre_prompt: "These are frames from a video. Please answer the following questions about the video with a short phrase."
+ post_prompt: ""
diff --git a/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2_text_freeform_hard.yaml b/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2_text_freeform_hard.yaml
new file mode 100644
index 00000000..8fdbd3c1
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2_text_freeform_hard.yaml
@@ -0,0 +1,24 @@
+task: "mix_evals_audio2text_freeform_hard"
+test_split: free_form_hard
+output_type: generate_until
+doc_to_visual: !function utils.mix_evals_audio2text_doc_to_audio
+doc_to_text: !function utils.mix_evals_audio2text_doc_to_text
+doc_to_target: !function utils.mix_evals_audio2text_doc_to_target
+process_results: !function utils.mix_evals_audio2text_process_results_freeform
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.mix_evals_audio2text_gpt_eval
+ higher_is_better: true
+
+generation_kwargs:
+ max_new_tokens: 64
+
+include: _default_template_yaml
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "Answer the question using a single word or phrase."
+ gpt4v:
+ pre_prompt: "These are frames from a video. Please answer the following questions about the video with a short phrase."
+ post_prompt: ""
diff --git a/lmms_eval/tasks/mix_evals/audio2text/utils.py b/lmms_eval/tasks/mix_evals/audio2text/utils.py
new file mode 100644
index 00000000..6210a51b
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/audio2text/utils.py
@@ -0,0 +1,159 @@
+import datetime
+import json
+import os
+import re
+import sys
+import time
+from pathlib import Path
+
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+import lmms_eval.tasks._task_utils.file_utils as file_utils
+from lmms_eval.filters.extraction import ExtendedRegexFilter
+
+with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
+ raw_data = f.readlines()
+ safe_data = []
+ for i, line in enumerate(raw_data):
+ # remove function definition since yaml load cannot handle it
+ if "!function" not in line:
+ safe_data.append(line)
+
+ config = yaml.safe_load("".join(safe_data))
+
+NUM_SECONDS_TO_SLEEP = 5
+GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
+API_TYPE = os.getenv("API_TYPE", "openai")
+
+if API_TYPE == "openai":
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "Authorization": f"Bearer {API_KEY}",
+ "Content-Type": "application/json",
+ }
+elif API_TYPE == "azure":
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "api-key": API_KEY,
+ "Content-Type": "application/json",
+ }
+
+eval_prompt = """You are an AI assistant who will help me to evaluate the quality of a model response to a few candidate ground truth answers.
+
+Scoring criteria:
+- The response perfectly reflects the meaning of the ground truth: 1 point
+- The response reflects none of the key points in the ground truth: 0 points
+- Parts of the response are correct, but some parts of the ground truth are not mentioned in the response: 0.5 points
+- Parts of the response are correct, but other parts of the response are not mentioned in the ground truth: 0.5 points
+
+Here are some examples of the scoring criteria and format:
+model response: Steam Cleaning Services
+ground truth: ["steam clean", "steam clean", "cleaning", "car", "steam clean"],
+Point: 1
+
+model response: A cowboy action shooter.
+ground truth: ["man"]
+Point: 1
+
+model response: I'm sorry, but I can't assist with that request.
+ground truth: ["quality"]
+Point: 0
+
+Let's begin this task:
+model response: {model_response}
+ground truth: {ground_truth}
+Point:"""
+
+
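+# Query the GPT judge with simple retry logic; returns (judge_output, judge_model_name) and falls back to a zero/empty answer when every attempt fails.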
+def get_eval(model_response: str, ground_truth: str, max_tokens: int, retries: int = 5):
+ global headers
+ content = eval_prompt.format(model_response=model_response, ground_truth=ground_truth)
+
+ messages = [
+ {"role": "user", "content": content},
+ ]
+
+ payload = {
+ "model": GPT_EVAL_MODEL_NAME,
+ "messages": messages,
+ "temperature": 0.2,
+ "max_tokens": max_tokens,
+ }
+
+ for attempt in range(retries):
+ try:
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+ response.raise_for_status()
+ response_data = response.json()
+
+ content = response_data["choices"][0]["message"]["content"].strip()
+ if content != "":
+ return content, response_data["model"]
+ break # If successful, break out of the loop
+
+ except Exception as e:
+ eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
+ if attempt < retries - 1: # If we have retries left, sleep and then continue to next attempt
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+ else: # If this was the last attempt, log and return empty
+ eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
+ return "0", ""
+ return "", ""
+
+
+def mix_evals_audio2text_doc_to_audio(doc):
+ return [doc["audio"]]
+
+
+def mix_evals_audio2text_doc_to_target(doc):
+ return doc["reference_answer"][0]
+
+
+# This is the place where you format your question
+def mix_evals_audio2text_doc_to_text(doc, lmms_eval_specific_kwargs=None):
+ if lmms_eval_specific_kwargs is None:
+ lmms_eval_specific_kwargs = {}
+ pre_prompt = ""
+ post_prompt = ""
+ if "pre_prompt" in lmms_eval_specific_kwargs:
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ if "post_prompt" in lmms_eval_specific_kwargs:
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+
+ user_prompt = doc["query"]
+
+ if pre_prompt:
+ user_prompt = f"{pre_prompt}\n{user_prompt}"
+
+ if post_prompt:
+ user_prompt = f"{user_prompt}\n{post_prompt}"
+ return user_prompt
+
+
+def mix_evals_audio2text_process_results_freeform(doc, result):
+ pred = result[0]
+ ground_truth_str = doc["reference_answer"][0]
+ content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str)
+ eval_answer, model_name = get_eval(model_response=pred, ground_truth=ground_truth_str, max_tokens=1024)
+ return {
+ "gpt_eval": {"pred": pred, "id": doc["id"], "target": ground_truth_str, "eval_answer": eval_answer, "gpt_prompt": content},
+ }
+
+
+def mix_evals_audio2text_gpt_eval(results, args):
+ score = 0
+ for result in results:
+ eval_answer = result["eval_answer"]
+ try:
+ # the judge replies with a bare point value (0, 0.5, or 1); parse the first number it contains
+ eval_score = float(re.search(r"([0-9.]+)", eval_answer).group(1))
+ except Exception as e:
+ eval_logger.error(f"Error parsing eval_score: {e}")
+ eval_score = 0.0
+ score += eval_score
+
+ return score / len(results)
diff --git a/lmms_eval/tasks/muchomusic/muchomusic.yaml b/lmms_eval/tasks/muchomusic/muchomusic.yaml
new file mode 100644
index 00000000..6876914d
--- /dev/null
+++ b/lmms_eval/tasks/muchomusic/muchomusic.yaml
@@ -0,0 +1,25 @@
+dataset_path: lmms-lab/muchomusic
+dataset_kwargs:
+ token: True
+
+task: "muchomusic"
+test_split: test
+doc_to_target: !function utils.muchomusic_doc_to_target
+doc_to_visual: !function utils.muchomusic_doc_to_audio
+doc_to_text: !function utils.muchomusic_doc_to_text
+doc_to_choice: !function utils.muchomusic_doc_to_choice
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nAnswer with the option's letter from the given choices directly: "
+metric_list:
+ - metric: accuracy
+ aggregation: mean
+ higher_is_better: true
+
+process_results: !function utils.muchomusic_process_results
+
+metadata:
+ gpt_eval_model_name: gpt-4o
+ version: 0.0
diff --git a/lmms_eval/tasks/muchomusic/utils.py b/lmms_eval/tasks/muchomusic/utils.py
new file mode 100644
index 00000000..a01fd59f
--- /dev/null
+++ b/lmms_eval/tasks/muchomusic/utils.py
@@ -0,0 +1,87 @@
+import datetime
+import json
+import os
+import random
+import re
+import sys
+import time
+from pathlib import Path
+
+import numpy as np
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
+
+
+def muchomusic_doc_to_audio(doc):
+ return [doc["context"]]
+
+
+def muchomusic_doc_to_text(doc, lmms_eval_specific_kwargs):
+ question = doc["instruction"]
+ answers = doc["choices"]
+ question = f"{question}\n{answers}"
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}{question}{post_prompt}"
+
+
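+# NOTE: doc["answer"] is assumed to store the option letter at index 1, e.g. "(A) ..." -> "A".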
+def muchomusic_doc_to_target(doc):
+ return doc["answer"][1]
+
+
+def muchomusic_doc_to_choice(doc):
+ return ["A", "B", "C", "D"]
+
+
+def muchomusic_process_results(doc, result):
+ response = result[0].strip()
+ all_choices = ["A", "B", "C", "D"]
+ pred = parse_multi_choice_response(response, all_choices) # Adapted from MMMU
+ gt_ans = doc["answer"][1]
+ score = 1.0 if pred == gt_ans else 0.0
+ return {"accuracy": score}
+
+
+def parse_multi_choice_response(response, all_choices):
+ """
+ Parse the prediction from the generated response.
+ Return the predicted choice letter e.g., A, B, C, D.
+ """
+ # Clean response of unwanted characters
+ for char in [",", ".", "!", "?", ";", ":", "'"]:
+ response = response.strip(char)
+ response = " " + response + " " # Add space to avoid partial match
+
+ candidates = []
+ # Look for choices with parentheses, e.g., (A)
+ for choice in all_choices:
+ if f"({choice})" in response:
+ candidates.append(choice)
+
+ # Look for simple choices, e.g., A, B, C
+ if len(candidates) == 0:
+ for choice in all_choices:
+ if f" {choice} " in response:
+ candidates.append(choice)
+
+ # Look for choices with periods, e.g., A., B., C.
+ if len(candidates) == 0:
+ for choice in all_choices:
+ if f"{choice}." in response:
+ candidates.append(choice)
+
+ # If no candidates, randomly choose one
+ if len(candidates) == 0:
+ pred_index = random.choice(all_choices)
+ elif len(candidates) > 1:
+ # If more than one candidate, choose the last one found
+ start_indexes = [response.rfind(f" {can} ") for can in candidates]
+ pred_index = candidates[np.argmax(start_indexes)]
+ else:
+ # If only one candidate, use it
+ pred_index = candidates[0]
+
+ return pred_index
diff --git a/lmms_eval/tasks/openhermes/openhermes.yaml b/lmms_eval/tasks/openhermes/openhermes.yaml
new file mode 100644
index 00000000..0e94b6f7
--- /dev/null
+++ b/lmms_eval/tasks/openhermes/openhermes.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/openhermes_instruction
+dataset_kwargs:
+ token: True
+
+task: "openhermes"
+test_split: test
+doc_to_target: "answer"
+doc_to_visual: !function utils.doc_to_audio
+doc_to_text: !function utils.doc_to_text
+
+generation_kwargs:
+ max_new_tokens: 1024
+ temperature: 0.2
+ top_p: 1.0
+ num_beams: 1
+
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nPlease give a detail answer to the question in the audio."
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.openhermes_aggregate_results
+ higher_is_better: true
+
+process_results: !function utils.openhermes_process_results
+
+metadata:
+ gpt_eval_model_name: gpt-4o
+ version: 0.0
diff --git a/lmms_eval/tasks/openhermes/utils.py b/lmms_eval/tasks/openhermes/utils.py
new file mode 100644
index 00000000..e5522c00
--- /dev/null
+++ b/lmms_eval/tasks/openhermes/utils.py
@@ -0,0 +1,139 @@
+import datetime
+import json
+import os
+import random
+import re
+import sys
+import time
+from pathlib import Path
+
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+import lmms_eval.tasks._task_utils.file_utils as file_utils
+from lmms_eval.filters.extraction import ExtendedRegexFilter
+
+
+def doc_to_audio(doc):
+ return [doc["context"]]
+
+
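+# The spoken instruction lives in the audio itself, so the text prompt only concatenates the pre/post prompts.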
+def doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}{post_prompt}"
+
+
+with open(Path(__file__).parent / "openhermes.yaml", "r") as f:
+ raw_data = f.readlines()
+ safe_data = []
+ for i, line in enumerate(raw_data):
+ # remove function definition since yaml load cannot handle it
+ if "!function" not in line:
+ safe_data.append(line)
+
+ config = yaml.safe_load("".join(safe_data))
+
+# specify api type and key in .env
+GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
+API_TYPE = os.getenv("API_TYPE", "azure")
+
+if API_TYPE == "openai":
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "Authorization": f"Bearer {API_KEY}",
+ "Content-Type": "application/json",
+ }
+
+elif API_TYPE == "azure":
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "api-key": API_KEY,
+ "Content-Type": "application/json",
+ }
+
+eval_prompt = """
+ [Question]
+ {question}
+
+ [Reference Answer]
+ {ground_truth}
+
+ [Model Answer]
+ {model_response}
+
+ [Task]
+ Rate the model's answer based on its alignment with the reference answer, focusing on accuracy and relevance to the reference provided. Please be critical on the details.
+ Criteria: Assess if the model's response mirrors the reference in terms of content, accuracy, and relevance.
+ Score0: The answer is completely misaligned, providing incorrect or irrelevant information compared to the reference.
+ Score1: The answer shows minimal alignment, often misunderstanding or providing irrelevant details unrelated to the reference.
+ Score2: The answer recognizes the topic but diverges significantly from the reference in accuracy or relevance.
+ Score3: The answer aligns with the reference generally but lacks detail or precise accuracy in some aspects.
+ Score4: The answer is mostly accurate and relevant, closely following the reference but could be clearer or more detailed.
+ Score5: The answer is highly accurate, detailed, and matches the reference answer perfectly, capturing its essence and detail.
+
+ Your response should be formatted as follows:
+ Explanation: (Provide a concise explanation of your rating, comparing the reference answer with the model's response. "The reference answer is [XXX], while the model's answer is [YYY]. I think ...")
+ Rating: (int)"""
+
+
+retries = 3
+NUM_SECONDS_TO_SLEEP = 5
+
+
+def get_eval(max_tokens: int, content: str, retries: int = retries):
+ global headers
+
+ messages = [
+ {"role": "user", "content": content},
+ ]
+
+ payload = {"model": GPT_EVAL_MODEL_NAME, "messages": messages, "temperature": 0.7, "max_tokens": max_tokens, "top_p": 0.95, "frequency_penalty": 0, "presence_penalty": 0, "stop": None}
+
+ for attempt in range(retries):
+ try:
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+ response.raise_for_status()
+ response_data = response.json()
+
+ content = response_data["choices"][0]["message"]["content"].strip()
+ if content != "":
+ return content, response_data["model"]
+ break # If successful, break out of the loop
+
+ except Exception as e:
+ eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
+ if attempt < retries - 1: # If we have retries left, sleep and then continue to next attempt
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+ else: # If this was the last attempt, log and return empty
+ eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
+ return "", ""
+ return "", ""
+
+
+def openhermes_process_results(doc, result):
+ pred = result[0]
+ ground_truth_str = doc["answer"]
+ content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str, question=doc["speech_instruction"])
+ eval_answer, model_name = get_eval(max_tokens=1024, content=content)
+ return {
+ "gpt_eval": {"eval_answer": eval_answer, "model_name": model_name},
+ }
+
+
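+# Average the 0-5 GPT ratings and rescale to a 0-100 score.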
+def openhermes_aggregate_results(results):
+ score = 0
+ for result in results:
+ try:
+ eval_answer = result["eval_answer"]
+ eval_score = re.search(r"([0-5])", eval_answer).group(1)
+ eval_score = float(eval_score)
+ except Exception as e:
+ eval_logger.error(f"Error parsing eval_score: {e}")
+ eval_score = 0.0
+ score += eval_score
+
+ return score / len(results) * 20
diff --git a/lmms_eval/tasks/people_speech/people_speech_val.yaml b/lmms_eval/tasks/people_speech/people_speech_val.yaml
new file mode 100644
index 00000000..b680549d
--- /dev/null
+++ b/lmms_eval/tasks/people_speech/people_speech_val.yaml
@@ -0,0 +1,29 @@
+dataset_path: lmms-lab/peoples_speech
+dataset_kwargs:
+ token: True
+task : "people_speech_val"
+test_split: val
+output_type: generate_until
+doc_to_visual: !function utils.people_speech_doc_to_audio
+doc_to_text: !function utils.people_speech_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.people_speech_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.people_speech_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
diff --git a/lmms_eval/tasks/people_speech/utils.py b/lmms_eval/tasks/people_speech/utils.py
new file mode 100644
index 00000000..af26f832
--- /dev/null
+++ b/lmms_eval/tasks/people_speech/utils.py
@@ -0,0 +1,182 @@
+import os
+import re
+import unicodedata
+
+import editdistance as ed
+import zhconv
+
+from lmms_eval.tasks.librispeech.cn_tn import TextNorm
+from lmms_eval.tasks.librispeech.whisper_normalizer.basic import BasicTextNormalizer
+from lmms_eval.tasks.librispeech.whisper_normalizer.english import EnglishTextNormalizer
+
+# ImportError: To support decoding audio files, please install 'librosa' and 'soundfile'.
+english_normalizer = EnglishTextNormalizer()
+chinese_normalizer = TextNorm(
+ to_banjiao=False,
+ to_upper=False,
+ to_lower=False,
+ remove_fillers=False,
+ remove_erhua=False,
+ check_chars=False,
+ remove_space=False,
+ cc_mode="",
+)
+basic_normalizer = BasicTextNormalizer()
+
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
+
+def people_speech_doc_to_audio(doc):
+ return [doc["audio"]]
+
+
+def people_speech_doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"
+
+
+def people_speech_process_result(doc, result):
+ pred = result[0] if len(result) > 0 else ""
+
+ gt = doc["gt"]
+ source = doc["source"]
+ task = doc["task"]
+
+ data_dict = {"gt": gt, "pred": pred, "source": source, "task": task}
+
+ return {"wer": data_dict}
+
+
+PUNCS = "!,.?;:"
+
+
+def remove_sp(text, language):
+ gt = re.sub(r"<\|.*?\|>", " ", text)
+ gt = re.sub(rf"\s+", r" ", gt) # Replace consecutive spaces in the text with a single space.
+ gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
+ gt = gt.lstrip(" ")
+ if language == "zh":
+ gt = re.sub(rf"\s+", r"", gt)
+ return gt
+
+
+class EvaluationTokenizer(object):
+ """A generic evaluation-time tokenizer, which leverages built-in tokenizers
+ in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
+ lowercasing, punctuation removal and character tokenization, which are
+ applied after sacreBLEU tokenization.
+
+ Args:
+ tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
+ lowercase (bool): lowercase the text.
+ punctuation_removal (bool): remove punctuation (based on unicode
+ category) from text.
+ character_tokenization (bool): tokenize the text to characters.
+ """
+
+ SPACE = chr(32)
+ SPACE_ESCAPE = chr(9601)
+ # ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
+
+ def __init__(
+ self,
+ tokenizer_type: str = "13a",
+ lowercase: bool = False,
+ punctuation_removal: bool = False,
+ character_tokenization: bool = False,
+ ):
+ from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
+ from sacrebleu.tokenizers.tokenizer_char import TokenizerChar
+ from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International
+ from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab
+ from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer
+ from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh
+
+ TOKENIZERS = {
+ "none": NoneTokenizer,
+ "13a": Tokenizer13a,
+ "intl": TokenizerV14International,
+ "zh": TokenizerZh,
+ "ja-mecab": TokenizerJaMecab,
+ "char": TokenizerChar,
+ }
+
+ assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
+ self.lowercase = lowercase
+ self.punctuation_removal = punctuation_removal
+ self.character_tokenization = character_tokenization
+ self.tokenizer = TOKENIZERS[tokenizer_type]
+ # self.tokenizer = tokenizer_none
+
+ @classmethod
+ def remove_punctuation(cls, sent: str):
+ """Remove punctuation based on Unicode category."""
+ return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t))
+
+ def tokenize(self, sent: str):
+ tokenized = self.tokenizer()(sent)
+
+ if self.punctuation_removal:
+ tokenized = self.remove_punctuation(tokenized)
+
+ if self.character_tokenization:
+ tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))
+
+ if self.lowercase:
+ tokenized = tokenized.lower()
+
+ return tokenized
+
+
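+# Corpus-level WER: total edit distance between tokenized reference/hypothesis pairs divided by total reference length; zh/yue are scored at the character level.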
+def compute_wer(refs, hyps, language):
+ distance = 0
+ ref_length = 0
+ tokenizer = EvaluationTokenizer(
+ tokenizer_type="none",
+ lowercase=True,
+ punctuation_removal=True,
+ character_tokenization=False,
+ )
+ for i in range(len(refs)):
+ ref = refs[i]
+ pred = hyps[i]
+ if language in ["yue"]:
+ ref = zhconv.convert(ref, "zh-cn")
+ pred = zhconv.convert(pred, "zh-cn")
+ if language in ["en"]:
+ ref = english_normalizer(ref)
+ pred = english_normalizer(pred)
+ if language in ["zh"]:
+ ref = chinese_normalizer(ref)
+ pred = chinese_normalizer(pred)
+ else:
+ ref = basic_normalizer(ref)
+ pred = basic_normalizer(pred)
+ ref_items = tokenizer.tokenize(ref).split()
+ pred_items = tokenizer.tokenize(pred).split()
+ if language in ["zh", "yue"]:
+ ref_items = [x for x in "".join(ref_items)]
+ pred_items = [x for x in "".join(pred_items)]
+ if i == 0:
+ print(f"ref: {ref}")
+ print(f"pred: {pred}")
+ print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
+ print(f"pred_items:\n{pred_items}\n{len(ref_items)}\n{ref_items[0]}")
+ distance += ed.eval(ref_items, pred_items)
+ ref_length += len(ref_items)
+ return distance / ref_length
+
+
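+# Aggregate WER (in percent) over all results; the language code is read from doc["task"] by dropping an assumed four-character prefix.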
+def people_speech_wer(results, args):
+ refs, hyps = [], []
+ for result in results:
+ lan = result["task"][4:]
+ gt = result["gt"]
+ response = result["pred"]
+ gt = remove_sp(gt, lan)
+ response = remove_sp(response, lan)
+ refs.append(gt)
+ hyps.append(response)
+ wer = compute_wer(refs, hyps, lan)
+ return wer * 100
diff --git a/lmms_eval/tasks/tedlium/tedlium_dev_test.yaml b/lmms_eval/tasks/tedlium/tedlium_dev_test.yaml
new file mode 100644
index 00000000..18810f4c
--- /dev/null
+++ b/lmms_eval/tasks/tedlium/tedlium_dev_test.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/tedlium
+dataset_kwargs:
+ token: True
+task : "tedlium_dev_test"
+test_split: val
+dataset_name: tedlium_dev_test
+output_type: generate_until
+doc_to_visual: !function utils.tedlium_doc_to_audio
+doc_to_text: !function utils.tedlium_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.tedlium_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.tedlium_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/tedlium/tedlium_long_form.yaml b/lmms_eval/tasks/tedlium/tedlium_long_form.yaml
new file mode 100644
index 00000000..f72983b7
--- /dev/null
+++ b/lmms_eval/tasks/tedlium/tedlium_long_form.yaml
@@ -0,0 +1,30 @@
+dataset_path: lmms-lab/tedlium
+dataset_kwargs:
+ token: True
+task : "tedlium_long_form"
+test_split: val
+dataset_name: tedlium_long_form
+output_type: generate_until
+doc_to_visual: !function utils.tedlium_doc_to_audio
+doc_to_text: !function utils.tedlium_doc_to_text
+doc_to_target: "gt"
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.tedlium_process_result
+metric_list:
+ - metric: wer
+ aggregation : !function utils.tedlium_wer
+ higher_is_better : false
+metadata:
+ - version: 0.0
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+ qwen2_audio:
+ pre_prompt: ""
+ post_prompt: " <|en|>"
\ No newline at end of file
diff --git a/lmms_eval/tasks/tedlium/utils.py b/lmms_eval/tasks/tedlium/utils.py
new file mode 100644
index 00000000..606b15e3
--- /dev/null
+++ b/lmms_eval/tasks/tedlium/utils.py
@@ -0,0 +1,182 @@
+import os
+import re
+import unicodedata
+
+import editdistance as ed
+import zhconv
+
+from lmms_eval.tasks.librispeech.cn_tn import TextNorm
+from lmms_eval.tasks.librispeech.whisper_normalizer.basic import BasicTextNormalizer
+from lmms_eval.tasks.librispeech.whisper_normalizer.english import EnglishTextNormalizer
+
+# ImportError: To support decoding audio files, please install 'librosa' and 'soundfile'.
+english_normalizer = EnglishTextNormalizer()
+chinese_normalizer = TextNorm(
+ to_banjiao=False,
+ to_upper=False,
+ to_lower=False,
+ remove_fillers=False,
+ remove_erhua=False,
+ check_chars=False,
+ remove_space=False,
+ cc_mode="",
+)
+basic_normalizer = BasicTextNormalizer()
+
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
+
+def tedlium_doc_to_audio(doc):
+ return [doc["audio"]]
+
+
+def tedlium_doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}"
+
+
+def tedlium_process_result(doc, result):
+ pred = result[0] if len(result) > 0 else ""
+
+ gt = doc["gt"]
+ source = doc["source"]
+ task = doc["task"]
+
+ data_dict = {"gt": gt, "pred": pred, "source": source, "task": task}
+
+ return {"wer": data_dict}
+
+
+PUNCS = "!,.?;:"
+
+
+def remove_sp(text, language):
+ gt = re.sub(r"<\|.*?\|>", " ", text)
+ gt = re.sub(rf"\s+", r" ", gt) # Replace consecutive spaces in the text with a single space.
+ gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
+ gt = gt.lstrip(" ")
+ if language == "zh":
+ gt = re.sub(rf"\s+", r"", gt)
+ return gt
+
+
+class EvaluationTokenizer(object):
+ """A generic evaluation-time tokenizer, which leverages built-in tokenizers
+ in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
+ lowercasing, punctuation removal and character tokenization, which are
+ applied after sacreBLEU tokenization.
+
+ Args:
+ tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
+ lowercase (bool): lowercase the text.
+ punctuation_removal (bool): remove punctuation (based on unicode
+ category) from text.
+ character_tokenization (bool): tokenize the text to characters.
+ """
+
+ SPACE = chr(32)
+ SPACE_ESCAPE = chr(9601)
+ # ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
+
+ def __init__(
+ self,
+ tokenizer_type: str = "13a",
+ lowercase: bool = False,
+ punctuation_removal: bool = False,
+ character_tokenization: bool = False,
+ ):
+ from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
+ from sacrebleu.tokenizers.tokenizer_char import TokenizerChar
+ from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International
+ from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab
+ from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer
+ from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh
+
+ TOKENIZERS = {
+ "none": NoneTokenizer,
+ "13a": Tokenizer13a,
+ "intl": TokenizerV14International,
+ "zh": TokenizerZh,
+ "ja-mecab": TokenizerJaMecab,
+ "char": TokenizerChar,
+ }
+
+ assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
+ self.lowercase = lowercase
+ self.punctuation_removal = punctuation_removal
+ self.character_tokenization = character_tokenization
+ self.tokenizer = TOKENIZERS[tokenizer_type]
+ # self.tokenizer = tokenizer_none
+
+ @classmethod
+ def remove_punctuation(cls, sent: str):
+ """Remove punctuation based on Unicode category."""
+ return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t))
+
+ def tokenize(self, sent: str):
+ tokenized = self.tokenizer()(sent)
+
+ if self.punctuation_removal:
+ tokenized = self.remove_punctuation(tokenized)
+
+ if self.character_tokenization:
+ tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))
+
+ if self.lowercase:
+ tokenized = tokenized.lower()
+
+ return tokenized
+
+
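+# Corpus-level WER: total edit distance between tokenized reference/hypothesis pairs divided by total reference length; zh/yue are scored at the character level.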
+def compute_wer(refs, hyps, language):
+ distance = 0
+ ref_length = 0
+ tokenizer = EvaluationTokenizer(
+ tokenizer_type="none",
+ lowercase=True,
+ punctuation_removal=True,
+ character_tokenization=False,
+ )
+ for i in range(len(refs)):
+ ref = refs[i]
+ pred = hyps[i]
+ if language in ["yue"]:
+ ref = zhconv.convert(ref, "zh-cn")
+ pred = zhconv.convert(pred, "zh-cn")
+ if language in ["en"]:
+ ref = english_normalizer(ref)
+ pred = english_normalizer(pred)
+ if language in ["zh"]:
+ ref = chinese_normalizer(ref)
+ pred = chinese_normalizer(pred)
+ else:
+ ref = basic_normalizer(ref)
+ pred = basic_normalizer(pred)
+ ref_items = tokenizer.tokenize(ref).split()
+ pred_items = tokenizer.tokenize(pred).split()
+ if language in ["zh", "yue"]:
+ ref_items = [x for x in "".join(ref_items)]
+ pred_items = [x for x in "".join(pred_items)]
+ if i == 0:
+ print(f"ref: {ref}")
+ print(f"pred: {pred}")
+ print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
+ print(f"pred_items:\n{pred_items}\n{len(ref_items)}\n{ref_items[0]}")
+ distance += ed.eval(ref_items, pred_items)
+ ref_length += len(ref_items)
+ return distance / ref_length
+
+
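+# Aggregate WER (in percent) over all results; the language code comes from doc["task"] minus an assumed four-character prefix.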
+def tedlium_wer(results, args):
+ refs, hyps = [], []
+ for result in results:
+ lan = result["task"][4:]
+ gt = result["gt"]
+ response = result["pred"]
+ gt = remove_sp(gt, lan)
+ response = remove_sp(response, lan)
+ refs.append(gt)
+ hyps.append(response)
+ wer = compute_wer(refs, hyps, lan)
+ return wer * 100
diff --git a/lmms_eval/tasks/vocalsound/_default_template_yaml b/lmms_eval/tasks/vocalsound/_default_template_yaml
new file mode 100644
index 00000000..73a813a1
--- /dev/null
+++ b/lmms_eval/tasks/vocalsound/_default_template_yaml
@@ -0,0 +1,14 @@
+dataset_path: lmms-lab/vocalsound
+dataset_kwargs:
+ token: True
+doc_to_target: "answer"
+doc_to_visual: !function utils.doc_to_audio
+doc_to_text: !function utils.doc_to_text
+doc_to_choice: !function utils.doc_to_choice
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "Classify the human vocal sound to VocalSound in English: "
+
+metadata:
+ version: 0.0
diff --git a/lmms_eval/tasks/vocalsound/utils.py b/lmms_eval/tasks/vocalsound/utils.py
new file mode 100644
index 00000000..33f89389
--- /dev/null
+++ b/lmms_eval/tasks/vocalsound/utils.py
@@ -0,0 +1,81 @@
+import datetime
+import json
+import os
+import random
+import re
+import sys
+import time
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
+
+
+def doc_to_audio(doc):
+ return [doc["audio"]]
+
+
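+# The vocal sound itself is the input; the text side only carries the classification instruction from the YAML config.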
+def doc_to_text(doc, lmms_eval_specific_kwargs):
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}{post_prompt}"
+
+
+classes = ["Laughter", "Sniff", "Throat", "Cough", "Sigh", "Sneeze"]
+
+
+def doc_to_choice(doc):
+ return ["Laughter", "Sniff", "Throat clearing", "Cough", "Sigh", "Sneeze"]
+
+
+def vocalsound_process_results(doc, result):
+ response = result[0].strip()
+ gt_ans = doc["answer"]
+ pred = get_answer(response)
+ score = 1.0 if pred == gt_ans else 0.0
+ return {"accuracy": {"overall": score, "age": doc["age_group"], "spk_id": doc["spk_id"]}}
+
+
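+# Overall accuracy plus per-gender (from the "f"/"m" prefix of spk_id) and per-age-group breakdowns.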
+def vocalsound_aggregate_results(results):
+ total_correct = 0
+ group_totals = defaultdict(int)
+ group_correct = defaultdict(int)
+
+ for result in results:
+ accuracy = result["overall"]
+ total_correct += accuracy
+
+ # Gender grouping
+ if result["spk_id"][0] == "f":
+ group_totals["female"] += 1
+ group_correct["female"] += accuracy
+ else:
+ group_totals["male"] += 1
+ group_correct["male"] += accuracy
+
+ # Age grouping
+ age_group = f"age{str(result['age'])}"
+ group_totals[age_group] += 1
+ group_correct[age_group] += accuracy
+
+ return {
+ "overall_accuracy": total_correct / len(results),
+ "categorical_accuracy": {
+ "male_accuracy": round(group_correct["male"] / group_totals.get("male", 1), 5), # Avoid division by zero
+ "female_accuracy": round(group_correct["female"] / group_totals.get("female", 1), 5),
+ "age_18_25_accuracy": round(group_correct["age1"] / group_totals.get("age1", 1), 5),
+ "age_26_48_accuracy": round(group_correct["age2"] / group_totals.get("age2", 1), 5),
+ "age_49_80_accuracy": round(group_correct["age3"] / group_totals.get("age3", 1), 5),
+ },
+ }
+
+
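+# Map a free-form response to a VocalSound class by case-insensitive substring match; "Throat" expands to "Throat clearing".
+# Returns None when nothing matches, which vocalsound_process_results scores as incorrect.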
+def get_answer(response):
+ for temp in classes:
+ if temp.lower() in response.lower():
+ return temp if temp != "Throat" else "Throat clearing"
diff --git a/lmms_eval/tasks/vocalsound/vocalsound_test.yaml b/lmms_eval/tasks/vocalsound/vocalsound_test.yaml
new file mode 100644
index 00000000..6e2f24a5
--- /dev/null
+++ b/lmms_eval/tasks/vocalsound/vocalsound_test.yaml
@@ -0,0 +1,13 @@
+task: "vocalsound_test"
+test_split: test
+metric_list:
+ - metric: accuracy
+ aggregation: !function utils.vocalsound_aggregate_results
+ higher_is_better: true
+ # - metric: submission
+ # aggregation: !function utils.vocalsound_aggregate_results_for_submission
+ # higher_is_better: true
+
+process_results: !function utils.vocalsound_process_results
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/vocalsound/vocalsound_val.yaml b/lmms_eval/tasks/vocalsound/vocalsound_val.yaml
new file mode 100644
index 00000000..22e83b9f
--- /dev/null
+++ b/lmms_eval/tasks/vocalsound/vocalsound_val.yaml
@@ -0,0 +1,13 @@
+task: "vocalsound_val"
+test_split: val
+metric_list:
+ - metric: accuracy
+ aggregation: !function utils.vocalsound_aggregate_results
+ higher_is_better: true
+ # - metric: submission
+ # aggregation: !function utils.vocalsound_aggregate_results_for_submission
+ # higher_is_better: true
+
+process_results: !function utils.vocalsound_process_results
+
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/wavcaps/utils.py b/lmms_eval/tasks/wavcaps/utils.py
new file mode 100644
index 00000000..09a4333f
--- /dev/null
+++ b/lmms_eval/tasks/wavcaps/utils.py
@@ -0,0 +1,134 @@
+import os
+import re
+import time
+from pathlib import Path
+
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+
+def wavcaps_doc_to_audio(doc):
+ return [doc["context"]]
+
+
+def wavcaps_doc_to_text(doc, lmms_eval_specific_kwargs):
+ question = doc["instruction"]
+ pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
+ post_prompt = lmms_eval_specific_kwargs["post_prompt"]
+ return f"{pre_prompt}{question}{post_prompt}"
+
+
+# functions for the clotho_asqa_v2 task, need to be tested later
+
+with open(Path(__file__).parent / "wavcaps.yaml", "r") as f:
+ raw_data = f.readlines()
+ safe_data = []
+ for i, line in enumerate(raw_data):
+ # remove function definition since yaml load cannot handle it
+ if "!function" not in line:
+ safe_data.append(line)
+
+ config = yaml.safe_load("".join(safe_data))
+
+
+NUM_SECONDS_TO_SLEEP = 5
+GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
+API_TYPE = os.getenv("API_TYPE", "azure")
+
+if API_TYPE == "openai":
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "Authorization": f"Bearer {API_KEY}",
+ "Content-Type": "application/json",
+ }
+elif API_TYPE == "azure":
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+ headers = {
+ "api-key": API_KEY,
+ "Content-Type": "application/json",
+ }
+
+eval_prompt = """
+ [Question]
+ {question}
+
+ [Reference Answer]
+ {ground_truth}
+
+ [Model Answer]
+ {model_response}
+
+ [Task]
+ Rate the model's answer based on its alignment with the reference answer, focusing on accuracy and relevance to the reference provided. Please be critical on the details.
+ Criteria: Assess if the model's response mirrors the reference in terms of content, accuracy, and relevance.
+ Score0: The answer is completely misaligned, providing incorrect or irrelevant information compared to the reference.
+ Score1: The answer shows minimal alignment, often misunderstanding or providing irrelevant details unrelated to the reference.
+ Score2: The answer recognizes the topic but diverges significantly from the reference in accuracy or relevance.
+ Score3: The answer aligns with the reference generally but lacks detail or precise accuracy in some aspects.
+ Score4: The answer is mostly accurate and relevant, closely following the reference but could be clearer or more detailed.
+ Score5: The answer is highly accurate, detailed, and matches the reference answer perfectly, capturing its essence and detail.
+
+ Your response should be formatted as follows:
+ Explanation: (Provide a concise explanation of your rating, comparing the reference answer with the model's response. "The reference answer is [XXX], while the model's answer is [YYY]. I think ...")
+ Rating: (int)"""
+
+
+# gpt-4
+def get_eval(max_tokens: int, content: str):
+ global headers
+
+ messages = [
+ {"role": "user", "content": content},
+ ]
+
+ payload = {"model": GPT_EVAL_MODEL_NAME, "messages": messages, "temperature": 0, "max_tokens": max_tokens, "n": 1}
+
+ for attempt in range(5):
+ try:
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+ response.raise_for_status()
+ response_data = response.json()
+
+ content = response_data["choices"][0]["message"]["content"].strip()
+ if content != "":
+ return content, response_data["model"]
+ break # If successful, break out of the loop
+
+ except Exception as e:
+ eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
+ if attempt < 4: # If we have retries left, sleep and then continue to next attempt
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+ else: # If this was the last attempt, log and return empty
+ eval_logger.error(f"All 5 attempts failed. Last error message: {e}")
+ return "", ""
+ return "", ""
+
+
+def wavcaps_process_results(doc, results):
+ pred = results[0]
+ ground_truth_str = doc["answer"]
+ content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str, question=doc["instruction"])
+ eval_answer, model_name = get_eval(max_tokens=1024, content=content)
+ return {
+ "gpt_eval": {"eval_answer": eval_answer, "model_name": model_name},
+ }
+
+
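+# Average the 0-5 ratings parsed from the judge's "Rating:" line; unparsable answers count as 0.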
+def wavcaps_aggregate_results(results):
+ score = 0
+ for result in results:
+ eval_answer = result["eval_answer"]
+
+ try:
+ match = re.search(r"Rating:\s*([0-5])\s*$", eval_answer)
+ eval_score = match.group(1) if match else 0
+ eval_score = float(eval_score)
+ except Exception as e:
+ eval_logger.error(f"Error parsing eval_score: {e}")
+ eval_score = 0.0
+ score += eval_score
+
+ return score / len(results)
diff --git a/lmms_eval/tasks/wavcaps/wavcaps.yaml b/lmms_eval/tasks/wavcaps/wavcaps.yaml
new file mode 100644
index 00000000..c142e2bb
--- /dev/null
+++ b/lmms_eval/tasks/wavcaps/wavcaps.yaml
@@ -0,0 +1,27 @@
+dataset_path: AudioLLMs/wavcaps_test
+dataset_kwargs:
+ token: True
+task : "wavcaps"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.wavcaps_doc_to_audio
+doc_to_text: !function utils.wavcaps_doc_to_text
+doc_to_target: "answer"
+lmms_eval_specific_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: ""
+generation_kwargs:
+ max_new_tokens: 256
+ temperature: 0
+ top_p: 1.0
+ num_beams: 1
+ do_sample: false
+process_results: !function utils.wavcaps_process_results
+metric_list:
+ - metric: gpt_eval
+ aggregation: !function utils.wavcaps_aggregate_results
+ higher_is_better: true
+metadata:
+ gpt_eval_model_name: gpt-4o
+ version: 0.0
diff --git a/lmms_eval/utils.py b/lmms_eval/utils.py
index 1df086ae..1604f9a1 100755
--- a/lmms_eval/utils.py
+++ b/lmms_eval/utils.py
@@ -641,6 +641,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
yaml_dir = os.path.dirname(yaml_path)
assert yaml_dir is not None
+ assert yaml_config is not None
if "include" in yaml_config:
include_path = yaml_config["include"]
diff --git a/pyproject.toml b/pyproject.toml
index bccee680..0ae7a953 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,13 @@ dependencies = [
]
[project.optional-dependencies]
+audio = [
+ "more-itertools",
+ "editdistance",
+ "zhconv",
+ "librosa",
+ "soundfile"
+]
metrics = [
"pywsd",
"spacy",
diff --git a/test_parse.py b/test_parse.py
new file mode 100644
index 00000000..ae69e337
--- /dev/null
+++ b/test_parse.py
@@ -0,0 +1,20 @@
+from lmms_eval.filters.extraction import MultiChoiceRegexFilter
+
+
+def parse_multi_choice_answer(answer):
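+ # NOTE: the "answer" argument is currently unused; this helper just runs a fixed demo of MultiChoiceRegexFilter.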
+ # Example responses and documents
+ model_responses = [
+ ["The answer is (B)", "I believe it is (A)", "(C) seems correct"], # Model response set 1
+ ["Answer is: B!", "Answer: B", "Answer: B"], # Model response set 2
+ ]
+
+ documents = [
+ {"choices": ["A. Apple", "B. Banana", "C. Cherry"]}, # Multiple choice options for question 1
+ {"choices": ["A. Alpha", "B. Beta", "C. Gamma"]}, # Multiple choice options for question 2
+ ]
+
+ # Instantiate the filter
+ multi_choice_filter = MultiChoiceRegexFilter(regex_pattern=r"\(([A-D])\)", group_select=0, ignore_case=False, ignore_punctuation=True)
+
+ filtered_responses = multi_choice_filter.apply(model_responses, documents)
+
+ # Print the filtered answers
+ for i, filtered in enumerate(filtered_responses):
+ print(f"Question {i+1} filtered responses: {filtered}")
+
+
+parse_multi_choice_answer("a")
diff --git a/tools/make_audio_hf_dataset.ipynb b/tools/make_audio_hf_dataset.ipynb
new file mode 100644
index 00000000..ad939f87
--- /dev/null
+++ b/tools/make_audio_hf_dataset.ipynb
@@ -0,0 +1,139 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import datasets\n",
+ "import json\n",
+ "\n",
+ "# Define the dataset features (audio, text, and source)\n",
+ "# change the data structure according to your needs, only important changes here is using datasets.Audio to load audio file\n",
+ "# And provide audio path in the data construction\n",
+ "# once loaded through datasets.Audio, we can access audio data, in the form of np.array(float32) using doc[\"audio\"][\"array\"]\n",
+ "features = datasets.Features(\n",
+ " {\n",
+ " \"audio\": datasets.Audio(sampling_rate=16000),\n",
+ " \"prompt\": datasets.Value(\"string\"),\n",
+ " \"gt\": datasets.Value(\"string\"),\n",
+ " \"source\": datasets.Value(\"string\"),\n",
+ " \"task\": datasets.Value(\"string\"),\n",
+ " }\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# loading data into dict form\n",
+ "def load_audio_data(data_path):\n",
+ " with open(data_path, 'r') as f:\n",
+ " data_lines = f.readlines()\n",
+ "\n",
+ " audio_list = []\n",
+ " prompt_list = []\n",
+ " gt_list = []\n",
+ " source_list = []\n",
+ " task_list = []\n",
+ "\n",
+ " for line in data_lines:\n",
+ " json_data = json.loads(line.strip())\n",
+ "\n",
+ " audio_list.append(json_data['audio']) # Path to the actual audio file\n",
+ " prompt_list.append(\"<|audio_bos|><|AUDIO|><|audio_eos|>\" + json_data['prompt'])\n",
+ " gt_list.append(json_data['gt'])\n",
+ " source_list.append(json_data['source'])\n",
+ " task_list.append(json_data['task'])\n",
+ "\n",
+ " # Return a dictionary where keys are features and values are lists of data\n",
+ " return {\n",
+ " 'audio': audio_list,\n",
+ " 'prompt': prompt_list,\n",
+ " 'gt': gt_list,\n",
+ " 'source': source_list,\n",
+ " 'task': task_list\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load data according to different task\n",
+ "def load_audio_data_task(data_path, task):\n",
+ " with open(data_path, 'r') as f:\n",
+ " data_lines = f.readlines()\n",
+ "\n",
+ " audio_list = []\n",
+ " prompt_list = []\n",
+ " gt_list = []\n",
+ " source_list = []\n",
+ " task_list = []\n",
+ "\n",
+ " for line in data_lines:\n",
+ " json_data = json.loads(line.strip())\n",
+ " if json_data['source'] == task: \n",
+ "\n",
+ " \n",
+ " audio_list.append(json_data['audio']) # Path to the actual audio file\n",
+ " prompt_list.append(\"<|audio_bos|><|AUDIO|><|audio_eos|>\" + json_data['prompt'])\n",
+ " gt_list.append(json_data['gt'])\n",
+ " source_list.append(json_data['source'])\n",
+ " task_list.append(json_data['task'])\n",
+ "\n",
+ " # Return a dictionary where keys are features and values are lists of data\n",
+ " return {\n",
+ " 'audio': audio_list,\n",
+ " 'prompt': prompt_list,\n",
+ " 'gt': gt_list,\n",
+ " 'source': source_list,\n",
+ " 'task': task_list\n",
+ " }\n",
+ "\n",
+ "\n",
+ "tasks = ['librispeech_test_other', 'librispeech_dev_other', 'librispeech_test_clean', 'librispeech_dev_clean']\n",
+ "\n",
+ "# description_root\n",
+ "data_description_path = \"./librispeech_eval.jsonl\"\n",
+ "\n",
+ "data_dict = {}\n",
+ "for task in tasks:\n",
+ "\n",
+ " # Load the dataset into a Hugging Face Dataset object\n",
+ " data = load_audio_data_task(data_description_path, task)\n",
+ "\n",
+ " # Create a Dataset from the data and features\n",
+ " dataset = datasets.Dataset.from_dict(data, features=features)\n",
+ "\n",
+ " # Verify the dataset structure\n",
+ " print(dataset)\n",
+ "\n",
+ " data_dict[task] = dataset\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = datasets.DatasetDict(data_dict)\n",
+ "data.push_to_hub(\"Alarak/librispeech\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}