From 624dd4c2d1201b1ebc717b30f98d6ce8f47477d1 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Wed, 21 Aug 2024 10:29:48 +0200 Subject: [PATCH] chore: improve benchmarking abstraction --- revision-scripts/mp_classification.py | 4 ++-- src/mattext/models/benchmark.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/revision-scripts/mp_classification.py b/revision-scripts/mp_classification.py index 219caec..763e5d1 100644 --- a/revision-scripts/mp_classification.py +++ b/revision-scripts/mp_classification.py @@ -20,8 +20,8 @@ def __len__(self): return self.txn.stat()['entries'] def get(self, index): - id = f"{index}".encode("ascii") - return pickle.loads(self.txn.get(id)) + id_ = f"{index}".encode("ascii") + return pickle.loads(self.txn.get(id_)) def create_json_from_lmdb(lmdb_path, output_dir): dataset = Dataset(lmdb_path) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 235c325..b79bfd5 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -138,7 +138,7 @@ def _record_predictions(self, task, fold, predictions, prediction_ids): class MatbenchmarkClassification(BaseBenchmark): def run_benchmarking(self, local_rank=None) -> None: - task = MatTextTask(task_name=self.task, is_classification=True) + task = self._initialize_task() for i, (exp_name, test_name) in enumerate( zip(self.exp_names, self.test_exp_names) @@ -152,6 +152,9 @@ def run_benchmarking(self, local_rank=None) -> None: self._save_results(task) + def _initialize_task(self): + return MatTextTask(task_name=self.task, is_classification=True) + def _get_finetuner(self, exp_cfg, local_rank, fold_name): return FinetuneClassificationModel(exp_cfg, local_rank, fold=fold_name)