Train from scratch #170

Closed · wants to merge 6 commits
102 changes: 102 additions & 0 deletions examples/train_from_scratch.py
@@ -0,0 +1,102 @@
# fmt: off
import logging


from transformers.tokenization_bert import BertTokenizer

from farm.data_handler.data_silo import DataSilo
from farm.data_handler.processor import BertStyleLMProcessor
from farm.modeling.adaptive_model import AdaptiveModel
from farm.modeling.language_model import LanguageModel
from farm.modeling.optimization import initialize_optimizer
from farm.modeling.prediction_head import BertLMHead, NextSentenceHead
from farm.train import Trainer
from farm.utils import set_all_seeds, MLFlowLogger, initialize_device_settings

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

ml_logger = MLFlowLogger(tracking_uri="")
ml_logger.init_experiment(experiment_name="from_scratch", run_name="debug")
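
# Note: tracking_uri is left empty here; point it at an MLflow server (e.g. the
# public one referenced in step 7 below) if you want the run tracked remotely.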

#########################
######## Settings
#########################
set_all_seeds(seed=39)
device, n_gpu = initialize_device_settings(use_cuda=True)
evaluate_every = 5000
vocab_size = 29999
# dev_filename = None
save_dir = "/opt/ml/model"
predictions_file = save_dir + "/predictions.json"
full_predictions_file = save_dir + "/full_predictions.json"
inference_multiprocessing = True

n_epochs = 10
learning_rate = 1e-4
warmup_proportion = 0.05
batch_size = 16 # (probably only possible via gradient accumulation steps)
max_seq_len = 64


# 1. Create a tokenizer (here from a local vocab file rather than a pretrained model)
tokenizer = BertTokenizer("vocab.txt")
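
# Assumption: "vocab.txt" is a locally trained WordPiece vocabulary. Its number
# of entries must match vocab_size above, since both the embedding matrix and
# the LM prediction head below are sized with vocab_size.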

# 2. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
processor = BertStyleLMProcessor(
    # data_dir="/opt/ml/input/data/train/lm_finetune_nips",
    data_dir="/Users/tanay",
    tokenizer=tokenizer, max_seq_len=max_seq_len,
    train_filename="full_corpus.txt",
    dev_split=2000 / 8_000_000,
    dev_filename=None,
    test_filename=None,
)

# 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
# and calculates a few descriptive statistics of our datasets
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)

# 4. Create an AdaptiveModel
# a) which consists of a freshly initialized (not pretrained) language model as a basis
language_model = LanguageModel.from_scratch("bert", vocab_size)

# b) and *two* prediction heads on top that are suited for BERT-style pretraining:
# masked language modeling and next-sentence prediction
lm_prediction_head = BertLMHead(768, vocab_size)  # 768 = hidden size of bert-base
next_sentence_head = NextSentenceHead([768, 2], task_name="nextsentence")  # binary NSP classifier

model = AdaptiveModel(
    language_model=language_model,
    prediction_heads=[lm_prediction_head, next_sentence_head],
    embeds_dropout_prob=0.1,
    lm_output_types=["per_token", "per_sequence"],
    device=device,
)

# 5. Create an optimizer
optimizer, warmup_linear = initialize_optimizer(
    model=model,
    learning_rate=learning_rate,
    warmup_proportion=warmup_proportion,
    n_batches=len(data_silo.loaders["train"]),
    n_epochs=n_epochs,
    grad_acc_steps=8,
)
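
# With grad_acc_steps=8, gradients are accumulated over 8 batches before every
# optimizer step, i.e. an effective batch size of 16 * 8 = 128 (cf. the note on
# batch_size above).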
# 6. Feed everything to the Trainer, which takes care of training our model and evaluates it from time to time
trainer = Trainer(
    optimizer=optimizer,
    data_silo=data_silo,
    epochs=n_epochs,
    n_gpu=n_gpu,
    warmup_linear=warmup_linear,
    evaluate_every=evaluate_every,
    device=device,
    grad_acc_steps=8,
)
# 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
model = trainer.train(model)

# 8. Hooray! You have a model. Store it:
model.save(save_dir)
processor.save(save_dir)
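
The stored directory can then be reloaded in a later session. A minimal sketch, assuming FARM's standard load helpers (not part of this PR; adjust paths to your setup):

# Sketch (assumption): reload the stored model and processor for further use.
from farm.modeling.adaptive_model import AdaptiveModel
from farm.data_handler.processor import Processor

model = AdaptiveModel.load(save_dir, device=device)
processor = Processor.load_from_dir(save_dir)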
138 changes: 116 additions & 22 deletions farm/data_handler/data_silo.py
@@ -5,6 +5,7 @@
from contextlib import ExitStack
from functools import partial
import random
from pathlib import Path

import numpy as np
from sklearn.utils.class_weight import compute_class_weight
@@ -20,6 +21,7 @@
from farm.data_handler.utils import grouper
from farm.utils import MLFlowLogger as MlLogger
from farm.utils import log_ascii_workers, calc_chunksize
from farm.utils import get_dict_checksum
from farm.visual.ascii.images import TRACTOR_SMALL

logger = logging.getLogger(__name__)
@@ -31,7 +33,16 @@ class DataSilo:
    calculate and display some statistics.
    """

    def __init__(self, processor, batch_size, distributed=False, automatic_loading=True, max_multiprocessing_chunksize=2000):
    def __init__(
        self,
        processor,
        batch_size,
        distributed=False,
        automatic_loading=True,
        max_multiprocessing_chunksize=2000,
        max_processes=128,
        checkpointing=False,
    ):
"""
:param processor: A dataset specific Processor object which will turn input (file or dict) into a Pytorch Dataset.
:type processor: Processor
Expand All @@ -52,15 +63,26 @@ def __init__(self, processor, batch_size, distributed=False, automatic_loading=T
        self.data = {}
        self.batch_size = batch_size
        self.class_weights = None
        self.max_processes = 128
        self.max_processes = max_processes
        self.max_multiprocessing_chunksize = max_multiprocessing_chunksize
        # In most cases we want to load all data automatically, but in some cases we rather want to do this later or
        # load from dicts instead of file (https://github.com/deepset-ai/FARM/issues/85)
        if automatic_loading:

        loaded_from_cache = False
        if checkpointing:  # Check if DataSets are present in cache
            checksum = self._get_checksum()
            dataset_path = Path(f"cache/data_silo/{checksum}")

            if dataset_path.exists():
                logger.info("Loading datasets from cache ...")
                self._load_dataset_from_cache(dataset_path)
                loaded_from_cache = True

        if not loaded_from_cache and automatic_loading:
            # In most cases we want to load all data automatically, but in some cases we rather want to do this
            # later or load from dicts instead of file (https://github.com/deepset-ai/FARM/issues/85)
            self._load_data()
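
        # Usage sketch (assumption, mirroring the cache logic above):
        #
        #   data_silo = DataSilo(processor=processor, batch_size=32, checkpointing=True)
        #
        # The first run preprocesses the data and writes cache/data_silo/<checksum>/;
        # subsequent runs with the same train_filename load the serialized datasets
        # from that path instead of re-running preprocessing.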

    @classmethod
    def _multiproc(cls, chunk, processor):
    def _dataset_from_chunk(cls, chunk, processor):
        """
        Creating a dataset for a chunk (= subset) of dicts. In multiprocessing:
        * we read in all dicts from a file
@@ -93,29 +115,42 @@ def _get_dataset(self, filename, dicts=None):
        random.shuffle(dicts)

        num_dicts = len(dicts)
        multiprocessing_chunk_size, num_cpus_used = calc_chunksize(num_dicts, max_chunksize=self.max_multiprocessing_chunksize)
        multiprocessing_chunk_size, num_cpus_used = calc_chunksize(
            num_dicts=num_dicts,
            max_processes=self.max_processes,
            max_chunksize=self.max_multiprocessing_chunksize,
        )

        with ExitStack() as stack:
            p = stack.enter_context(mp.Pool(processes=num_cpus_used))

            logger.info(
                f"Got ya {num_cpus_used} parallel workers to convert {num_dicts} dictionaries "
                f"to pytorch datasets (chunksize = {multiprocessing_chunk_size})..."
            )
            log_ascii_workers(num_cpus_used, logger)

            results = p.imap(
                partial(self._multiproc, processor=self.processor),
                grouper(dicts, multiprocessing_chunk_size),
                chunksize=1,
            )
            if self.max_processes > 1:  # use multiprocessing only when max_processes > 1
                p = stack.enter_context(mp.Pool(processes=num_cpus_used))

                logger.info(
                    f"Got ya {num_cpus_used} parallel workers to convert {num_dicts} dictionaries "
                    f"to pytorch datasets (chunksize = {multiprocessing_chunk_size})..."
                )
                log_ascii_workers(num_cpus_used, logger)

                results = p.imap(
                    partial(self._dataset_from_chunk, processor=self.processor),
                    grouper(dicts, multiprocessing_chunk_size),
                    chunksize=1,
                )
            else:
                logger.info(
                    f"Multiprocessing disabled, using a single worker to convert {num_dicts} "
                    f"dictionaries to pytorch datasets."
                )

                results = map(partial(self._dataset_from_chunk, processor=self.processor), grouper(dicts, num_dicts))

            datasets = []
            with tqdm(total=len(dicts), unit=' Dicts') as pbar:

            with tqdm(total=len(dicts), unit=' Dicts', desc="Preprocessing Dataset") as pbar:
                for dataset, tensor_names in results:
                    datasets.append(dataset)
                    pbar.update(multiprocessing_chunk_size)

            concat_datasets = ConcatDataset(datasets)
            return concat_datasets, tensor_names

@@ -175,11 +210,70 @@ def _load_data(self, train_dicts=None, dev_dicts=None, test_dicts=None):
            logger.info("No test set is being loaded")
            self.data["test"] = None

        self._save_dataset_to_cache()

        # derive stats and meta data
        self._calculate_statistics()
        # self.calculate_class_weights()

        self._initialize_data_loaders()

    def _get_checksum(self):
        """
        Get checksum based on a dict to ensure validity of cached DataSilo
        """
        # Keys in the dict identify uniqueness for a given DataSilo.
        payload_dict = {
            "train_filename": str(Path(self.processor.train_filename).absolute())
        }
        checksum = get_dict_checksum(payload_dict)
        return checksum
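
    # Note: `get_dict_checksum` is imported from farm.utils but its implementation
    # is not part of this diff. A minimal sketch consistent with its usage here
    # (an assumption, not necessarily the actual farm.utils code):
    #
    #   import hashlib, json
    #
    #   def get_dict_checksum(payload_dict):
    #       return hashlib.md5(json.dumps(payload_dict, sort_keys=True).encode("utf-8")).hexdigest()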

    def _load_dataset_from_cache(self, cache_dir):
        """
        Load serialized dataset from a cache.
        """
        self.data["train"] = torch.load(cache_dir / "train_dataset")

        dev_dataset_path = cache_dir / "dev_dataset"
        if dev_dataset_path.exists():
            self.data["dev"] = torch.load(dev_dataset_path)
        else:
            self.data["dev"] = None

        test_dataset_path = cache_dir / "test_dataset"
        if test_dataset_path.exists():
            self.data["test"] = torch.load(test_dataset_path)
        else:
            self.data["test"] = None

        self.tensor_names = torch.load(cache_dir / "tensor_names")

        # derive stats and meta data
        self._calculate_statistics()
        # self.calculate_class_weights()

        self._initialize_data_loaders()

    def _save_dataset_to_cache(self):
        """
        Serialize and save dataset to a cache.
        """
        checksum = self._get_checksum()

        cache_dir = Path(f"cache/data_silo/{checksum}")
        cache_dir.mkdir(parents=True, exist_ok=True)

        torch.save(self.data["train"], cache_dir / "train_dataset")

        if self.data["dev"]:
            torch.save(self.data["dev"], cache_dir / "dev_dataset")

        if self.data["test"]:
            torch.save(self.data["test"], cache_dir / "test_dataset")

        torch.save(self.tensor_names, cache_dir / "tensor_names")
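
        # Note that the cache key above derives only from the absolute train_filename
        # (see _get_checksum), so changes to e.g. the tokenizer or max_seq_len do not
        # invalidate an existing cache entry; remove cache/data_silo/ manually in
        # that case.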

    def _initialize_data_loaders(self):
        """ Initializing train, dev and test data loaders for the already loaded datasets """

19 changes: 18 additions & 1 deletion farm/modeling/language_model.py
@@ -53,6 +53,12 @@ def __init_subclass__(cls, **kwargs):
    def forward(self, input_ids, padding_mask, **kwargs):
        raise NotImplementedError

    @classmethod
    def from_scratch(cls, model_type, vocab_size):
        if model_type.lower() == "bert":
            model = Bert
        else:
            raise NotImplementedError(
                f"Training from scratch is currently only implemented for 'bert', got '{model_type}'"
            )
        return model.from_scratch(vocab_size)

    @classmethod
    def load(cls, pretrained_model_name_or_path, n_added_tokens=0, **kwargs):
        """
@@ -95,13 +101,15 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, **kwargs):
            language_model = cls.subclasses["Bert"].load(pretrained_model_name_or_path, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            language_model = cls.subclasses["XLNet"].load(pretrained_model_name_or_path, **kwargs)
        elif "albert" in pretrained_model_name_or_path:
            language_model = cls.subclasses["Albert"].load(pretrained_model_name_or_path, **kwargs)
        else:
            language_model = None

        if not language_model:
            raise Exception(
                f"Model not found for {pretrained_model_name_or_path}. Either supply the local path for a saved model "
                f"or one of bert/roberta/xlnet models that can be downloaded from remote. Here's the list of available "
                f"or one of bert/roberta/xlnet/albert models that can be downloaded from remote. Here's the list of available "
                f"models: https://farm.deepset.ai/api/modeling.html#farm.modeling.language_model.LanguageModel.load"
            )

@@ -246,6 +254,15 @@ def __init__(self):
        self.model = None
        self.name = "bert"

    @classmethod
    def from_scratch(cls, vocab_size, name="bert", language="en"):
        bert = cls()
        bert.name = name
        bert.language = language
        config = BertConfig(vocab_size=vocab_size)
        bert.model = BertModel(config)
        return bert
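
        # With only vocab_size overridden, BertConfig keeps the bert-base defaults
        # (hidden_size=768, 12 layers, 12 attention heads), which is why the example
        # script sizes its prediction heads with 768.
        # Usage (as in examples/train_from_scratch.py above):
        #   language_model = LanguageModel.from_scratch("bert", vocab_size)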

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """