Add caching of datasets in DataSilo (#177)
tanaysoni authored Dec 19, 2019 · 1 parent 44f05c7 · commit fc824ff
Showing 2 changed files with 85 additions and 3 deletions.
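
This commit adds opt-in caching to DataSilo: when the new checkpointing flag is set, the preprocessed train/dev/test datasets are serialized under cache/data_silo/<checksum> on the first run, and later runs with the same train file load them from disk instead of repeating preprocessing. A minimal usage sketch follows; the Processor construction is omitted and build_processor is a hypothetical placeholder, and only the checkpointing argument is new in this commit:

    from farm.data_handler.data_silo import DataSilo

    # Any FARM Processor set up as usual (e.g. a TextClassificationProcessor);
    # build_processor is a stand-in, not part of this commit.
    processor = build_processor()

    # First run: data is preprocessed and serialized to cache/data_silo/<md5>.
    data_silo = DataSilo(processor=processor, batch_size=32, checkpointing=True)

    # A later run with the same train file finds the checksum directory,
    # logs "Loading datasets from cache ...", and skips preprocessing.
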
79 changes: 76 additions & 3 deletions farm/data_handler/data_silo.py
@@ -5,6 +5,7 @@
 from contextlib import ExitStack
 from functools import partial
 import random
+from pathlib import Path
 
 import numpy as np
 from sklearn.utils.class_weight import compute_class_weight
@@ -20,6 +21,7 @@
 from farm.data_handler.utils import grouper
 from farm.utils import MLFlowLogger as MlLogger
 from farm.utils import log_ascii_workers, calc_chunksize
+from farm.utils import get_dict_checksum
 from farm.visual.ascii.images import TRACTOR_SMALL
 
 logger = logging.getLogger(__name__)
@@ -39,6 +41,7 @@ def __init__(
         automatic_loading=True,
         max_multiprocessing_chunksize=2000,
         max_processes=128,
+        checkpointing=False,
     ):
         """
         :param processor: A dataset specific Processor object which will turn input (file or dict) into a Pytorch Dataset.
@@ -62,9 +65,20 @@
         self.class_weights = None
         self.max_processes = max_processes
         self.max_multiprocessing_chunksize = max_multiprocessing_chunksize
-        # In most cases we want to load all data automatically, but in some cases we rather want to do this later or
-        # load from dicts instead of file (https://github.com/deepset-ai/FARM/issues/85)
-        if automatic_loading:
+
+        loaded_from_cache = False
+        if checkpointing:  # Check if DataSets are present in cache
+            checksum = self._get_checksum()
+            dataset_path = Path(f"cache/data_silo/{checksum}")
+
+            if dataset_path.exists():
+                logger.info("Loading datasets from cache ...")
+                self._load_dataset_from_cache(dataset_path)
+                loaded_from_cache = True
+
+        if not loaded_from_cache and automatic_loading:
+            # In most cases we want to load all data automatically, but in some cases we would rather
+            # do this later or load from dicts instead of a file (https://github.com/deepset-ai/FARM/issues/85)
            self._load_data()
 
     @classmethod
@@ -196,11 +210,70 @@ def _load_data(self, train_dicts=None, dev_dicts=None, test_dicts=None):
             logger.info("No test set is being loaded")
             self.data["test"] = None
 
+        self._save_dataset_to_cache()
+
         # derive stats and meta data
         self._calculate_statistics()
         # self.calculate_class_weights()
 
         self._initialize_data_loaders()
 
+    def _get_checksum(self):
+        """
+        Get checksum based on a dict to ensure validity of cached DataSilo
+        """
+        # keys in the dict identify uniqueness for a given DataSilo.
+        payload_dict = {
+            "train_filename": str(Path(self.processor.train_filename).absolute())
+        }
+        checksum = get_dict_checksum(payload_dict)
+        return checksum
+
+    def _load_dataset_from_cache(self, cache_dir):
+        """
+        Load serialized dataset from a cache.
+        """
+        self.data["train"] = torch.load(cache_dir / "train_dataset")
+
+        dev_dataset_path = cache_dir / "dev_dataset"
+        if dev_dataset_path.exists():
+            self.data["dev"] = torch.load(dev_dataset_path)
+        else:
+            self.data["dev"] = None
+
+        test_dataset_path = cache_dir / "test_dataset"
+        if test_dataset_path.exists():
+            self.data["test"] = torch.load(test_dataset_path)
+        else:
+            self.data["test"] = None
+
+        self.tensor_names = torch.load(cache_dir / "tensor_names")
+
+        # derive stats and meta data
+        self._calculate_statistics()
+        # self.calculate_class_weights()
+
+        self._initialize_data_loaders()
+
+    def _save_dataset_to_cache(self):
+        """
+        Serialize and save dataset to a cache.
+        """
+        checksum = self._get_checksum()
+
+        cache_dir = Path(f"cache/data_silo/{checksum}")
+        cache_dir.mkdir(parents=True, exist_ok=True)
+
+        torch.save(self.data["train"], cache_dir / "train_dataset")
+
+        if self.data["dev"]:
+            torch.save(self.data["dev"], cache_dir / "dev_dataset")
+
+        if self.data["test"]:
+            torch.save(self.data["test"], cache_dir / "test_dataset")
+
+        torch.save(self.tensor_names, cache_dir / "tensor_names")
+
     def _initialize_data_loaders(self):
         """ Initializing train, dev and test data loaders for the already loaded datasets """
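
Taken together, _save_dataset_to_cache and _load_dataset_from_cache persist the three datasets plus the tensor names under a checksum-keyed directory:

    cache/data_silo/<md5-checksum>/
        train_dataset
        dev_dataset      # written only if a dev set exists
        test_dataset     # written only if a test set exists
        tensor_names

The serialization is plain torch.save/torch.load on whole dataset objects. A minimal round-trip sketch, assuming the datasets are TensorDatasets; the tensors below are illustrative stand-ins for FARM's preprocessed features:

    from pathlib import Path

    import torch
    from torch.utils.data import TensorDataset

    cache_dir = Path("cache/data_silo/example")  # illustrative path
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Stand-in for a preprocessed dataset: input ids, padding mask, labels.
    dataset = TensorDataset(
        torch.randint(0, 100, (8, 16)),       # input_ids
        torch.ones(8, 16, dtype=torch.long),  # padding_mask
        torch.randint(0, 2, (8,)),            # label_ids
    )

    torch.save(dataset, cache_dir / "train_dataset")

    # Loading restores the full dataset object; no re-tokenization is needed.
    restored = torch.load(cache_dir / "train_dataset")
    assert len(restored) == len(dataset)
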
9 changes: 9 additions & 0 deletions farm/utils.py
@@ -1,3 +1,5 @@
+import hashlib
+import json
 import logging
 import random
 
@@ -243,3 +245,10 @@ def decode_squad_id(part_1, part_2):
     assert len(hexa) == 24
     return hexa
 
+
+def get_dict_checksum(payload_dict):
+    """
+    Get MD5 checksum for a dict.
+    """
+    checksum = hashlib.md5(json.dumps(payload_dict, sort_keys=True).encode("utf-8")).hexdigest()
+    return checksum
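
The helper hashes a canonical JSON rendering of the dict, so key order does not affect the result. A quick demonstration of the determinism the cache lookup relies on; the dict keys and values below are illustrative, while the actual payload in this commit contains only train_filename:

    import hashlib
    import json

    def get_dict_checksum(payload_dict):
        # Same definition as above: MD5 over a sort_keys JSON dump.
        return hashlib.md5(
            json.dumps(payload_dict, sort_keys=True).encode("utf-8")
        ).hexdigest()

    a = get_dict_checksum({"train_filename": "/data/train.tsv", "seed": 42})
    b = get_dict_checksum({"seed": 42, "train_filename": "/data/train.tsv"})
    assert a == b  # sort_keys makes the checksum independent of key order
    assert a != get_dict_checksum({"train_filename": "/data/other.tsv"})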
