Skip to content

Commit

Permalink
[Optimization] Optimize build requests, postprocess (#479)
Browse files Browse the repository at this point in the history
* Add process with media args for faster postprocess

* Use iterator for building requests

* Remove audio from media dataset
  • Loading branch information
kcz358 authored Dec 27, 2024
1 parent 3c00482 commit c6e9a36
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
1 change: 1 addition & 0 deletions lmms_eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def parse_eval_args() -> argparse.Namespace:
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
parser.add_argument("--process_with_media", action="store_true", help="Whether you will process you dataset with audio, image. By default set to False" "In case some benchmarks need to be processed with media, set this flag to True.")
args = parser.parse_args()
return args

Expand Down
26 changes: 23 additions & 3 deletions lmms_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import datasets
import numpy as np
from accelerate import Accelerator
from datasets import DownloadConfig, Image, Sequence
from datasets import Audio, DownloadConfig, Image, Sequence
from huggingface_hub import snapshot_download
from loguru import logger as eval_logger
from PIL import ImageFile
Expand Down Expand Up @@ -430,9 +430,10 @@ def build_all_requests(
if cache_requests and (not cached_instances or rewrite_requests_cache) and limit is not None:
limit = None

doc_id_docs = list(self.doc_iterator(rank=rank, limit=limit, world_size=world_size))
doc_id_docs = self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
doc_iterator_for_counting = itertools.islice(range(len(self.test_docs())), rank, limit, world_size) if self.has_test_docs() else itertools.islice(range(len(self.validation_docs())), rank, limit, world_size)

num_docs = len(doc_id_docs)
num_docs = sum(1 for _ in doc_iterator_for_counting)

for doc_id, doc in tqdm(
doc_id_docs,
Expand Down Expand Up @@ -1064,6 +1065,8 @@ def concat_tar_parts(tar_parts, output_tar):
remove_cols.append(feature)
elif isinstance(features[feature], Sequence) and isinstance(features[feature].feature, Image):
remove_cols.append(feature)
elif isinstance(features[feature], Audio):
remove_cols.append(feature)
for remove_col in remove_cols:
self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(remove_col)

Expand Down Expand Up @@ -1093,10 +1096,27 @@ def validation_docs(self) -> datasets.Dataset:
if self.has_validation_docs():
return self.dataset[self.config.validation_split]

def validation_docs_no_media(self) -> datasets.Dataset:
    """Return the validation split with media (image/audio) columns stripped.

    Returns None when the task defines no validation split.
    """
    if not self.has_validation_docs():
        return None
    return self.dataset_no_image[self.config.validation_split]

def test_docs(self) -> datasets.Dataset:
    """Return the test split of the dataset (None when the task has no test split)."""
    if not self.has_test_docs():
        return None
    return self.dataset[self.config.test_split]

def test_docs_no_media(self) -> datasets.Dataset:
    """Return the test split with media (image/audio) columns stripped.

    Returns None when the task defines no test split.
    """
    if not self.has_test_docs():
        return None
    return self.dataset_no_image[self.config.test_split]

@property
def eval_docs_no_media(self) -> Union[datasets.Dataset, List[dict]]:
    """Media-free documents used for evaluation.

    Prefers the test split when present, otherwise falls back to the
    validation split.

    Raises:
        ValueError: if the task has neither a test nor a validation split.
    """
    if self.has_test_docs():
        return self.test_docs_no_media()
    if self.has_validation_docs():
        return self.validation_docs_no_media()
    raise ValueError(f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!")

def fewshot_docs(self):
if self.config.fewshot_split is not None:
return self.dataset[self.config.fewshot_split]
Expand Down
5 changes: 4 additions & 1 deletion lmms_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,10 @@ def evaluate(
instances.sort(key=lambda x: x.idx)
# iterate over different filters used
for filter_key in task.instances[0].filtered_resps.keys():
doc_iterator = task.doc_iterator(rank=RANK, limit=limit, world_size=WORLD_SIZE)
if not cli_args.process_with_media:
doc_iterator = create_iterator(enumerate(task.eval_docs_no_media), rank=RANK, limit=int(limit) if limit else None, world_size=WORLD_SIZE)
else:
doc_iterator = task.doc_iterator(rank=RANK, limit=limit, world_size=WORLD_SIZE)
doc_iterator_for_counting = itertools.islice(range(len(task.test_docs())), RANK, limit, WORLD_SIZE) if task.has_test_docs() else itertools.islice(range(len(task.validation_docs())), RANK, limit, WORLD_SIZE)
total_docs = sum(1 for _ in doc_iterator_for_counting)
pbar = tqdm(total=total_docs, desc=f"Postprocessing", disable=(RANK != 0))
Expand Down

0 comments on commit c6e9a36

Please sign in to comment.