Skip to content

Commit

Permalink
chore: improve benchmarking abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
n0w0f committed Aug 21, 2024
1 parent e26424b commit 624dd4c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions revision-scripts/mp_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def __len__(self):
return self.txn.stat()['entries']

def get(self, index):
id = f"{index}".encode("ascii")
return pickle.loads(self.txn.get(id))
id_ = f"{index}".encode("ascii")
return pickle.loads(self.txn.get(id_))

def create_json_from_lmdb(lmdb_path, output_dir):
dataset = Dataset(lmdb_path)
Expand Down
5 changes: 4 additions & 1 deletion src/mattext/models/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _record_predictions(self, task, fold, predictions, prediction_ids):

class MatbenchmarkClassification(BaseBenchmark):
def run_benchmarking(self, local_rank=None) -> None:
task = MatTextTask(task_name=self.task, is_classification=True)
task = self._initialize_task()

for i, (exp_name, test_name) in enumerate(
zip(self.exp_names, self.test_exp_names)
Expand All @@ -152,6 +152,9 @@ def run_benchmarking(self, local_rank=None) -> None:

self._save_results(task)

def _initialize_task(self):
return MatTextTask(task_name=self.task, is_classification=True)

def _get_finetuner(self, exp_cfg, local_rank, fold_name):
return FinetuneClassificationModel(exp_cfg, local_rank, fold=fold_name)

Expand Down

0 comments on commit 624dd4c

Please sign in to comment.