Commit

init commit
thevasudevgupta committed Jun 25, 2021
1 parent 0b096eb commit 9c2b86a
Showing 10 changed files with 621 additions and 507 deletions.
11 changes: 11 additions & 0 deletions examples/research_projects/jax-projects/big_bird/README.md
@@ -0,0 +1,11 @@

Install the requirements:

```shell
pip3 install -qr requirements.txt
```

Download the Natural Questions validation split from the Hugging Face Hub:

```shell
mkdir natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/natural_questions-validation.arrow -P natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/dataset_info.json -P natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/state.json -P natural-questions-validation
```
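
The three downloaded files look like an on-disk `datasets` dataset. Assuming they form a `Dataset.save_to_disk`-style directory, a minimal sketch of loading the split back is:

```python
from datasets import load_from_disk

# assumption: the directory created above holds a `save_to_disk`-style layout
val_dataset = load_from_disk("natural-questions-validation")
print(val_dataset)
```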
348 changes: 348 additions & 0 deletions examples/research_projects/jax-projects/big_bird/bigbird_flax.py
@@ -0,0 +1,348 @@
import json
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import joblib
import optax
import wandb
from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm

from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule


class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
    """
    BigBirdForQuestionAnswering with a CLS head on top for predicting the answer category.

    Keeping the rest of the module unchanged lets us load these weights with
    FlaxBigBirdForQuestionAnswering.
    """

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    add_pooling_layer: bool = True

    def setup(self):
        super().setup()
        self.cls = nn.Dense(5, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        outputs = super().__call__(*args, **kwargs)
        cls_out = self.cls(outputs[2])
        return outputs[:2] + (cls_out,)


class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
    module_class = FlaxBigBirdForNaturalQuestionsModule


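# Total loss for Natural Questions: mean of the start-position, end-position and
# answer-category cross-entropy losses.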
def calculate_loss_for_nq(
    start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels
):
    def cross_entropy(logits, labels, reduction=None):
        """
        Args:
            logits: (..., num_classes)
            labels: (...) integer class indices
        """
        num_classes = logits.shape[-1]
        labels = (labels[..., None] == jnp.arange(num_classes)[None]).astype("f4")
        logits = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.sum(labels * logits, axis=-1)
        if reduction is not None:
            loss = reduction(loss)
        return loss

    cross_entropy = partial(cross_entropy, reduction=jnp.mean)
    start_loss = cross_entropy(start_logits, start_labels)
    end_loss = cross_entropy(end_logits, end_labels)
    pooled_loss = cross_entropy(pooled_logits, pooled_labels)
    return (start_loss + end_loss + pooled_loss) / 3


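# Training hyper-parameters; `batch_size` (set in `__post_init__`) is the global
# batch size, i.e. `batch_size_per_device * jax.device_count()`.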
@dataclass
class Args:
    model_id: str = "google/bigbird-roberta-base"
    logging_steps: int = 3000
    save_steps: int = 10500

    block_size: int = 128
    num_random_blocks: int = 3

    batch_size_per_device: int = 1
    gradient_accumulation_steps: int = None  # not implemented currently
    max_epochs: int = 5

    # tx_args
    lr: float = 3e-5
    init_lr: float = 0.0
    warmup_steps: int = 20000
    weight_decay: float = 0.0095

    save_dir: str = "bigbird-roberta-natural-questions"
    base_dir: str = "training-expt"
    tr_data_path: str = "data/nq-training.jsonl"
    val_data_path: str = "data/nq-validation.jsonl"

    def __post_init__(self):
        os.makedirs(self.base_dir, exist_ok=True)
        self.save_dir = os.path.join(self.base_dir, self.save_dir)
        self.batch_size = self.batch_size_per_device * jax.device_count()


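# Pads every example to a fixed `max_length` (static shapes keep TPU compilation happy)
# and shards each batch across the local devices for `pmap`.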
@dataclass
class DataCollator:
    pad_id: int
    max_length: int = 4096  # no dynamic padding on TPUs

    def __call__(self, batch):
        batch = self.collate_fn(batch)
        batch = jax.tree_map(shard, batch)
        return batch

    def collate_fn(self, features):
        input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
        batch = {
            "input_ids": jnp.array(input_ids, dtype=jnp.int32),
            "attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
            "start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
            "end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
            "pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
        }
        return batch

    def fetch_inputs(self, input_ids: list):
        inputs = [self._fetch_inputs(ids) for ids in input_ids]
        return zip(*inputs)

    def _fetch_inputs(self, input_ids: list):
        attention_mask = [1 for _ in range(len(input_ids))]
        while len(input_ids) < self.max_length:
            input_ids.append(self.pad_id)
            attention_mask.append(0)
        return input_ids, attention_mask


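# Yields dict-of-lists batches from a `datasets.Dataset`, reshuffling once per epoch
# when a seed is given; any trailing partial batch is dropped.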
def get_batched_dataset(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        yield dict(batch)


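# One replicated training step: `jax.pmap` runs it on every device, and the loss and
# gradients are averaged across devices with `jax.lax.pmean` over the "batch" axis
# before the optimizer update.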
@partial(jax.pmap, axis_name="batch")
def train_step(state, drp_rng, **model_inputs):
    def loss_fn(params):
        start_labels = model_inputs.pop("start_labels")
        end_labels = model_inputs.pop("end_labels")
        pooled_labels = model_inputs.pop("pooled_labels")

        outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
        start_logits, end_logits, pooled_logits = outputs

        return state.loss_fn(
            start_logits,
            start_labels,
            end_logits,
            end_labels,
            pooled_logits,
            pooled_labels,
        )

    drp_rng, new_drp_rng = jax.random.split(drp_rng)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    grads = jax.lax.pmean(grads, "batch")

    state = state.apply_gradients(grads=grads)
    return state, metrics, new_drp_rng


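# Replicated evaluation step: forward pass only, returning the cross-device mean loss.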
@partial(jax.pmap, axis_name="batch")
def val_step(state, **model_inputs):
    start_labels = model_inputs.pop("start_labels")
    end_labels = model_inputs.pop("end_labels")
    pooled_labels = model_inputs.pop("pooled_labels")

    outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
    start_logits, end_logits, pooled_logits = outputs

    loss = state.loss_fn(
        start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels
    )
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return metrics


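# flax TrainState extended with the loss function and the (currently unused)
# gradient-accumulation setting; both are static fields, not part of the pytree.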
class TrainState(train_state.TrainState):
    gradient_accumulation_steps: int = struct.field(pytree_node=False)
    loss_fn: Callable = struct.field(pytree_node=False)


@dataclass
class Trainer:
    args: Args
    data_collator: Callable
    train_step_fn: Callable
    val_step_fn: Callable
    model_save_fn: Callable
    logger: wandb
    scheduler_fn: Callable = None

    def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
        params = model.params
        state = TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=tx,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            loss_fn=calculate_loss_for_nq,
        )
        if ckpt_dir is not None:
            params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
            tx_args = {
                "lr": args.lr,
                "init_lr": args.init_lr,
                "warmup_steps": args.warmup_steps,
                "num_train_steps": num_train_steps,
                "weight_decay": args.weight_decay,
            }
            tx, lr = build_tx(**tx_args)
            # rebuild the custom TrainState (not the base class) so that `loss_fn` and
            # `gradient_accumulation_steps` survive the restore
            state = TrainState(
                step=step,
                apply_fn=model.__call__,
                params=params,
                tx=tx,
                opt_state=opt_state,
                gradient_accumulation_steps=args.gradient_accumulation_steps,
                loss_fn=calculate_loss_for_nq,
            )
            self.args = args
            self.data_collator = data_collator
            self.scheduler_fn = lr
            model.params = params
        state = jax_utils.replicate(state)
        return state

    def train(self, state, tr_dataset, val_dataset):
        args = self.args
        total = len(tr_dataset) // args.batch_size

        rng = jax.random.PRNGKey(0)
        drp_rng = jax.random.split(rng, jax.device_count())
        for epoch in range(args.max_epochs):
            running_loss = jnp.array(0, dtype=jnp.float32)
            tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
            i = 0
            for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
                batch = self.data_collator(batch)
                state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
                running_loss += jax_utils.unreplicate(metrics["loss"])
                i += 1
                if i % args.logging_steps == 0:
                    state_step = jax_utils.unreplicate(state.step)
                    tr_loss = running_loss.item() / i
                    lr = self.scheduler_fn(state_step - 1)

                    eval_loss = self.evaluate(state, val_dataset)
                    logging_dict = dict(
                        step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
                    )
                    tqdm.write(str(logging_dict))
                    self.logger.log(logging_dict, commit=True)

                if i % args.save_steps == 0:
                    self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)

    def evaluate(self, state, dataset):
        dataloader = get_batched_dataset(dataset, self.args.batch_size)
        total = len(dataset) // self.args.batch_size
        running_loss = jnp.array(0, dtype=jnp.float32)
        i = 0
        for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
            batch = self.data_collator(batch)
            metrics = self.val_step_fn(state, **batch)
            running_loss += jax_utils.unreplicate(metrics["loss"])
            i += 1
        return running_loss / i

    def save_checkpoint(self, save_dir, state):
        state = jax_utils.unreplicate(state)
        print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
        self.model_save_fn(save_dir, params=state.params)
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
        joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
        print("DONE")


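# Restores everything written by `Trainer.save_checkpoint`: model params, optimizer
# state, global step, training args and the data collator.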
def restore_checkpoint(save_dir, state):
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())

    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    args = joblib.load(os.path.join(save_dir, "args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator


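# Learning-rate schedule: linear warmup from `init_lr` to `lr` over `warmup_steps`,
# then linear decay towards ~0 over the remaining training steps.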
def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
    decay_steps = num_train_steps - warmup_steps
    warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
    lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
    return lr


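# AdamW with the schedule above; weight decay is masked out for biases and LayerNorm scales.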
def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
    def weight_decay_mask(params):
        params = traverse_util.flatten_dict(params)
        # decay everything except biases and LayerNorm scale parameters, identified by
        # the flattened parameter path (the dict keys), not the parameter values
        mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in params}
        return traverse_util.unflatten_dict(mask)

    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)

    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
    return tx, lr