Commit 9c2b86a (parent: 0b096eb)
Showing 10 changed files with 621 additions and 507 deletions.
examples/research_projects/jax-projects/big_bird/README.md (11 additions, 0 deletions)
@@ -0,0 +1,11 @@

```shell
pip3 install -qr requirements.txt
```

```shell
mkdir natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/natural_questions-validation.arrow -P natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/dataset_info.json -P natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/state.json -P natural-questions-validation
```
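The three files fetched above are the on-disk layout the `datasets` library writes when a dataset is saved, so the validation split can be read back directly. A minimal sketch (not part of this diff), assuming the downloaded directory is used as-is:

```python
from datasets import load_from_disk

# Directory created by the wget commands above.
val_dataset = load_from_disk("natural-questions-validation")
print(val_dataset)  # expected columns include input_ids, start_token, end_token, category
```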
examples/research_projects/jax-projects/big_bird/bigbird_flax.py (348 additions, 0 deletions)
@@ -0,0 +1,348 @@

```python
import json
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable

import wandb
import flax.linen as nn
import jax
import jax.numpy as jnp
import joblib
import optax
from flax import traverse_util, struct, jax_utils
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm

from transformers import (
    BigBirdConfig,
    FlaxBigBirdForQuestionAnswering,
)
from transformers.models.big_bird.modeling_flax_big_bird import (
    FlaxBigBirdForQuestionAnsweringModule,
)


class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
    """
    BigBirdForQuestionAnswering with a CLS head on top for predicting the answer category.
    This way we can load its weights with FlaxBigBirdForQuestionAnswering.
    """

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    add_pooling_layer: bool = True

    def setup(self):
        super().setup()
        self.cls = nn.Dense(5, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        outputs = super().__call__(*args, **kwargs)
        cls_out = self.cls(outputs[2])
        return outputs[:2] + (cls_out,)


class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
    module_class = FlaxBigBirdForNaturalQuestionsModule


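# Loss for Natural Questions: averages one-hot cross-entropy over the three
# heads (start position, end position, answer category).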
def calculate_loss_for_nq(
    start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels
):
    def cross_entropy(logits, labels, reduction=None):
        """
        Args:
            logits: (bsz, ..., num_classes) unnormalized scores
            labels: (bsz, ...) integer class indices
        """
        num_classes = logits.shape[-1]
        labels = (labels[..., None] == jnp.arange(num_classes)[None]).astype("f4")
        logits = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.sum(labels * logits, axis=-1)
        if reduction is not None:
            loss = reduction(loss)
        return loss

    cross_entropy = partial(cross_entropy, reduction=jnp.mean)
    start_loss = cross_entropy(start_logits, start_labels)
    end_loss = cross_entropy(end_logits, end_labels)
    pooled_loss = cross_entropy(pooled_logits, pooler_labels)
    return (start_loss + end_loss + pooled_loss) / 3


@dataclass
class Args:
    model_id: str = "google/bigbird-roberta-base"
    logging_steps: int = 3000
    save_steps: int = 10500

    block_size: int = 128
    num_random_blocks: int = 3

    batch_size_per_device: int = 1
    gradient_accumulation_steps: int = None  # it's not implemented currently
    max_epochs: int = 5

    # tx_args
    lr: float = 3e-5
    init_lr: float = 0.0
    warmup_steps: int = 20000
    weight_decay: float = 0.0095

    save_dir: str = "bigbird-roberta-natural-questions"
    base_dir: str = "training-expt"
    tr_data_path: str = "data/nq-training.jsonl"
    val_data_path: str = "data/nq-validation.jsonl"

    def __post_init__(self):
        os.makedirs(self.base_dir, exist_ok=True)
        self.save_dir = os.path.join(self.base_dir, self.save_dir)
        self.batch_size = self.batch_size_per_device * jax.device_count()


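# Pads every example to a fixed `max_length` (static shapes, required on TPU)
# and shards each batch across the local devices for `jax.pmap`.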
@dataclass
class DataCollator:

    pad_id: int
    max_length: int = 4096  # no dynamic padding on TPUs

    def __call__(self, batch):
        batch = self.collate_fn(batch)
        batch = jax.tree_map(shard, batch)
        return batch

    def collate_fn(self, features):
        input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
        batch = {
            "input_ids": jnp.array(input_ids, dtype=jnp.int32),
            "attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
            "start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
            "end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
            "pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
        }
        return batch

    def fetch_inputs(self, input_ids: list):
        inputs = [self._fetch_inputs(ids) for ids in input_ids]
        return zip(*inputs)

    def _fetch_inputs(self, input_ids: list):
        attention_mask = [1 for _ in range(len(input_ids))]
        while len(input_ids) < self.max_length:
            input_ids.append(self.pad_id)
            attention_mask.append(0)
        return input_ids, attention_mask


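# Yields dict batches of size `batch_size`; shuffles once per epoch when a seed
# is given and drops the last incomplete batch.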
def get_batched_dataset(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        yield dict(batch)


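# One pmapped training step: each device computes the loss on its shard, then
# loss and gradients are averaged across devices with `lax.pmean` before the
# optimizer update.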
@partial(jax.pmap, axis_name="batch")
def train_step(state, drp_rng, **model_inputs):

    def loss_fn(params):
        start_labels = model_inputs.pop("start_labels")
        end_labels = model_inputs.pop("end_labels")
        pooled_labels = model_inputs.pop("pooled_labels")

        outputs = state.apply_fn(
            **model_inputs, params=params, dropout_rng=drp_rng, train=True
        )
        start_logits, end_logits, pooled_logits = outputs

        return state.loss_fn(
            start_logits,
            start_labels,
            end_logits,
            end_labels,
            pooled_logits,
            pooled_labels,
        )

    drp_rng, new_drp_rng = jax.random.split(drp_rng)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    grads = jax.lax.pmean(grads, "batch")

    state = state.apply_gradients(grads=grads)
    return state, metrics, new_drp_rng


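# Evaluation step: same cross-device averaged loss, but without dropout or
# gradient computation.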
@partial(jax.pmap, axis_name="batch")
def val_step(state, **model_inputs):
    start_labels = model_inputs.pop("start_labels")
    end_labels = model_inputs.pop("end_labels")
    pooled_labels = model_inputs.pop("pooled_labels")

    outputs = state.apply_fn(
        **model_inputs, params=state.params, train=False
    )
    start_logits, end_logits, pooled_logits = outputs

    loss = state.loss_fn(
        start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels
    )
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return metrics


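# TrainState extended with the loss function and the gradient-accumulation
# count (not implemented yet, per Args) as static, non-pytree fields.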
class TrainState(train_state.TrainState):
    gradient_accumulation_steps: int = struct.field(pytree_node=False)
    loss_fn: Callable = struct.field(pytree_node=False)


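# Minimal trainer: replicates the state across devices, loops over epochs,
# logs train/eval loss to wandb every `logging_steps`, and writes a checkpoint
# (params, optimizer state, args, data collator) every `save_steps`.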
@dataclass
class Trainer:
    args: Args
    data_collator: Callable
    train_step_fn: Callable
    val_step_fn: Callable
    model_save_fn: Callable
    logger: wandb
    scheduler_fn: Callable = None

    def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
        params = model.params
        state = TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=tx,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            loss_fn=calculate_loss_for_nq,
        )
        if ckpt_dir is not None:
            params, opt_state, step, args, data_collator = restore_checkpoint(
                ckpt_dir, state
            )
            tx_args = {
                "lr": args.lr,
                "init_lr": args.init_lr,
                "warmup_steps": args.warmup_steps,
                "num_train_steps": num_train_steps,
                "weight_decay": args.weight_decay,
            }
            tx, lr = build_tx(**tx_args)
            state = train_state.TrainState(
                step=step,
                apply_fn=model.__call__,
                params=params,
                tx=tx,
                opt_state=opt_state,
            )
            self.args = args
            self.data_collator = data_collator
            self.scheduler_fn = lr
            model.params = params
        state = jax_utils.replicate(state)
        return state

    def train(self, state, tr_dataset, val_dataset):
        args = self.args
        total = len(tr_dataset) // args.batch_size

        rng = jax.random.PRNGKey(0)
        drp_rng = jax.random.split(rng, jax.device_count())
        for epoch in range(args.max_epochs):
            running_loss = jnp.array(0, dtype=jnp.float32)
            tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
            i = 0
            for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
                batch = self.data_collator(batch)
                state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
                running_loss += jax_utils.unreplicate(metrics["loss"])
                i += 1
                if i % args.logging_steps == 0:
                    state_step = jax_utils.unreplicate(state.step)
                    tr_loss = running_loss.item() / i
                    lr = self.scheduler_fn(state_step - 1)

                    eval_loss = self.evaluate(state, val_dataset)
                    logging_dict = dict(step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item())
                    tqdm.write(str(logging_dict))
                    self.logger.log(logging_dict, commit=True)

                if i % args.save_steps == 0:
                    self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)

    def evaluate(self, state, dataset):
        dataloader = get_batched_dataset(dataset, self.args.batch_size)
        total = len(dataset) // self.args.batch_size
        running_loss = jnp.array(0, dtype=jnp.float32)
        i = 0
        for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
            batch = self.data_collator(batch)
            metrics = self.val_step_fn(state, **batch)
            running_loss += jax_utils.unreplicate(metrics["loss"])
            i += 1
        return running_loss / i

    def save_checkpoint(self, save_dir, state):
        state = jax_utils.unreplicate(state)
        print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
        self.model_save_fn(save_dir, params=state.params)
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
        joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
        print("DONE")


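# Counterpart of Trainer.save_checkpoint: reads back the msgpack-serialized
# params and optimizer state, the joblib-dumped Args and DataCollator, and the
# saved step counter.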
def restore_checkpoint(save_dir, state):
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())

    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    args = joblib.load(os.path.join(save_dir, "args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator


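# Learning-rate schedule: linear warmup from `init_lr` to `lr` over
# `warmup_steps`, then linear decay to 1e-7 over the remaining steps.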
def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
    decay_steps = num_train_steps - warmup_steps
    warmup_fn = optax.linear_schedule(
        init_value=init_lr, end_value=lr, transition_steps=warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=lr, end_value=1e-7, transition_steps=decay_steps
    )
    lr = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
    )
    return lr


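# AdamW with the schedule above; weight decay is masked out for biases and
# LayerNorm scales.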
def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
    def weight_decay_mask(params):
        params = traverse_util.flatten_dict(params)
        mask = {
            k: (k[-1] != "bias" and k[-2:] != ("LayerNorm", "scale"))
            for k in params
        }
        return traverse_util.unflatten_dict(mask)

    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)

    tx = optax.adamw(
        learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask
    )
    return tx, lr
```
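The entry-point script that actually wires these pieces together is another of the commit's changed files and is not shown in this section. The sketch below is a hypothetical minimal driver under that caveat: the dataset paths and loading calls, the tokenizer used for `pad_id`, and the wandb project name are all assumptions, not something this diff specifies.

```python
# Hypothetical driver; paths, dataset columns, and project name are assumptions.
import wandb
from datasets import load_dataset, load_from_disk
from transformers import BigBirdTokenizerFast

from bigbird_flax import (
    Args,
    DataCollator,
    FlaxBigBirdForNaturalQuestions,
    Trainer,
    build_tx,
    train_step,
    val_step,
)

args = Args()
model = FlaxBigBirdForNaturalQuestions.from_pretrained(
    args.model_id, block_size=args.block_size, num_random_blocks=args.num_random_blocks
)
tokenizer = BigBirdTokenizerFast.from_pretrained(args.model_id)

# Assumed data sources: a pre-tokenized jsonl for training (see Args.tr_data_path)
# and the validation set downloaded as described in the README.
tr_dataset = load_dataset("json", data_files=args.tr_data_path)["train"]
val_dataset = load_from_disk("natural-questions-validation")

num_train_steps = args.max_epochs * (len(tr_dataset) // args.batch_size)
tx, lr_schedule = build_tx(
    args.lr, args.init_lr, args.warmup_steps, num_train_steps, args.weight_decay
)

trainer = Trainer(
    args=args,
    data_collator=DataCollator(pad_id=tokenizer.pad_token_id, max_length=4096),
    train_step_fn=train_step,
    val_step_fn=val_step,
    model_save_fn=model.save_pretrained,
    logger=wandb.init(project="bigbird-natural-questions"),  # assumed project name
    scheduler_fn=lr_schedule,
)
state = trainer.create_state(model, tx, num_train_steps, ckpt_dir=None)
trainer.train(state, tr_dataset, val_dataset)
```

Resuming instead passes `ckpt_dir=...` to `create_state`, which rebuilds the optimizer with `build_tx` and restores the saved params, optimizer state, `Args`, and `DataCollator` via `restore_checkpoint`.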