Commit 579a488

style

thevasudevgupta committed Jun 25, 2021
1 parent 9c2b86a commit 579a488

Showing 4 changed files with 109 additions and 124 deletions.
67 changes: 20 additions & 47 deletions examples/research_projects/jax-projects/big_bird/bigbird_flax.py
@@ -4,24 +4,20 @@
 from functools import partial
 from typing import Callable
 
-import wandb
-from tqdm.auto import tqdm
-
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import joblib
 import optax
-from flax import traverse_util, struct, jax_utils
+import wandb
+from flax import jax_utils, struct, traverse_util
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import shard
+from tqdm.auto import tqdm
 
-from transformers import (
-    BigBirdConfig,
-    FlaxBigBirdForQuestionAnswering
-)
-from transformers.models.big_bird.modeling_flax_big_bird import \
-    FlaxBigBirdForQuestionAnsweringModule
+from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
+from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
 
 
 class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
@@ -42,16 +38,14 @@ def setup(self):
     def __call__(self, *args, **kwargs):
         outputs = super().__call__(*args, **kwargs)
         cls_out = self.cls(outputs[2])
-        return outputs[:2] + (cls_out, )
+        return outputs[:2] + (cls_out,)
 
 
 class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
     module_class = FlaxBigBirdForNaturalQuestionsModule
 
 
-def calculate_loss_for_nq(
-    start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels
-):
+def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
     def cross_entropy(logits, labels, reduction=None):
         """
         Args:
@@ -83,7 +77,6 @@ class Args:
     num_random_blocks: int = 3
 
     batch_size_per_device: int = 1
-    gradient_accumulation_steps: int = None  # it's not implemented currently
     max_epochs: int = 5
 
     # tx_args
@@ -147,15 +140,12 @@ def get_batched_dataset(dataset, batch_size, seed=None):
 
 @partial(jax.pmap, axis_name="batch")
 def train_step(state, drp_rng, **model_inputs):
-
     def loss_fn(params):
         start_labels = model_inputs.pop("start_labels")
         end_labels = model_inputs.pop("end_labels")
         pooled_labels = model_inputs.pop("pooled_labels")
 
-        outputs = state.apply_fn(
-            **model_inputs, params=params, dropout_rng=drp_rng, train=True
-        )
+        outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
         start_logits, end_logits, pooled_logits = outputs
 
         return state.loss_fn(
@@ -183,20 +173,15 @@ def val_step(state, **model_inputs):
     end_labels = model_inputs.pop("end_labels")
     pooled_labels = model_inputs.pop("pooled_labels")
 
-    outputs = state.apply_fn(
-        **model_inputs, params=state.params, train=False
-    )
+    outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
     start_logits, end_logits, pooled_logits = outputs
 
-    loss = state.loss_fn(
-        start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels
-    )
+    loss = state.loss_fn(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
     return metrics
 
 
 class TrainState(train_state.TrainState):
-    gradient_accumulation_steps: int = struct.field(pytree_node=False)
     loss_fn: Callable = struct.field(pytree_node=False)
 
 
@@ -216,13 +201,10 @@ def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
             apply_fn=model.__call__,
             params=params,
             tx=tx,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
             loss_fn=calculate_loss_for_nq,
         )
         if ckpt_dir is not None:
-            params, opt_state, step, args, data_collator = restore_checkpoint(
-                ckpt_dir, state
-            )
+            params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
             tx_args = {
                 "lr": args.lr,
                 "init_lr": args.init_lr,
@@ -266,7 +248,9 @@ def train(self, state, tr_dataset, val_dataset):
                 lr = self.scheduler_fn(state_step - 1)
 
                 eval_loss = self.evaluate(state, val_dataset)
-                logging_dict = dict(step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item())
+                logging_dict = dict(
+                    step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
+                )
                 tqdm.write(str(logging_dict))
                 self.logger.log(logging_dict, commit=True)
 
@@ -319,30 +303,19 @@ def restore_checkpoint(save_dir, state):
 
 def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
     decay_steps = num_train_steps - warmup_steps
-    warmup_fn = optax.linear_schedule(
-        init_value=init_lr, end_value=lr, transition_steps=warmup_steps
-    )
-    decay_fn = optax.linear_schedule(
-        init_value=lr, end_value=1e-7, transition_steps=decay_steps
-    )
-    lr = optax.join_schedules(
-        schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
-    )
+    warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
+    decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
+    lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
     return lr
 
 
 def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
     def weight_decay_mask(params):
         params = traverse_util.flatten_dict(params)
-        mask = {
-            k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale"))
-            for k, v in params.items()
-        }
+        mask = {k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale")) for k, v in params.items()}
        return traverse_util.unflatten_dict(mask)
 
     lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)
 
-    tx = optax.adamw(
-        learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask
-    )
+    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
     return tx, lr
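
Aside from dropping the unused `gradient_accumulation_steps` field (its own comment noted it was never implemented), the changes in this file are formatting only: imports re-sorted and wrapped calls collapsed onto single lines. For context, here is a minimal sketch of how the optimizer pieces above compose; the hyperparameter values are illustrative, not taken from this commit, and it assumes it is run next to `bigbird_flax.py`:

```python
# Sketch, not part of the commit: illustrative values for build_tx.
from bigbird_flax import build_tx

# AdamW whose weight-decay mask (built inside build_tx) skips biases and
# LayerNorm scales, with linear warmup followed by linear decay to 1e-7.
tx, lr_schedule = build_tx(lr=3e-5, init_lr=0.0, warmup_steps=2000, num_train_steps=10000, weight_decay=0.01)

print(lr_schedule(0))     # 0.0 -> start of warmup
print(lr_schedule(2000))  # ~3e-5 -> peak rate once warmup ends
```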
110 changes: 69 additions & 41 deletions examples/research_projects/jax-projects/big_bird/evaluate.py
@@ -1,27 +1,33 @@
-from datasets import load_from_disk
-
 import jax
 import jax.numpy as jnp
 from bigbird_flax import FlaxBigBirdForNaturalQuestions
 from transformers import BigBirdTokenizerFast
+from datasets import load_from_disk
 
+
 CATEGORY_MAPPING = {0: "null", 1: "short", 2: "long", 3: "yes", 4: "no"}
-PUNCTUATION_SET_TO_EXCLUDE = set(''.join(['‘', '’', '´', '`', '.', ',', '-', '"']))
+PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))
 
+
 def get_sub_answers(answers, begin=0, end=None):
-  return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]
+    return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]
 
+
 def expand_to_aliases(given_answers, make_sub_answers=False):
-  if make_sub_answers:
-    # if answers are longer than one word, make sure a prediction is correct if it corresponds to the complete 1: or :-1 sub word
-    # *e.g.* if the correct answer contains a prefix such as "the", or "a"
-    given_answers = given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
-  answers = []
-  for answer in given_answers:
-    alias = answer.replace('_', ' ').lower()
-    alias = ''.join(c if c not in PUNCTUATION_SET_TO_EXCLUDE else ' ' for c in alias)
-    answers.append(' '.join(alias.split()).strip())
-  return set(answers)
+    if make_sub_answers:
+        # if answers are longer than one word, make sure a prediction is correct if it corresponds to the complete 1: or :-1 sub word
+        # *e.g.* if the correct answer contains a prefix such as "the", or "a"
+        given_answers = (
+            given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
+        )
+    answers = []
+    for answer in given_answers:
+        alias = answer.replace("_", " ").lower()
+        alias = "".join(c if c not in PUNCTUATION_SET_TO_EXCLUDE else " " for c in alias)
+        answers.append(" ".join(alias.split()).strip())
+    return set(answers)
 
+
 def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
     best_start_scores, best_start_idx = jax.lax.top_k(start_scores, top_k)
@@ -34,47 +40,60 @@ def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100
 
     return best_start_idx[best_score % top_k], best_end_idx[best_score // top_k]
 
+
 def format_dataset(sample):
-    question = sample['question']['text']
-    context = sample['document']['tokens']['token']
-    is_html = sample['document']['tokens']['is_html']
-    long_answers = sample['annotations']['long_answer']
-    short_answers = sample['annotations']['short_answers']
+    question = sample["question"]["text"]
+    context = sample["document"]["tokens"]["token"]
+    is_html = sample["document"]["tokens"]["is_html"]
+    long_answers = sample["annotations"]["long_answer"]
+    short_answers = sample["annotations"]["short_answers"]
 
-  context_string = " ".join([context[i] for i in range(len(context)) if not is_html[i]])
+    context_string = " ".join([context[i] for i in range(len(context)) if not is_html[i]])
 
     # 0 - No ; 1 - Yes
-    for answer in sample['annotations']['yes_no_answer']:
+    for answer in sample["annotations"]["yes_no_answer"]:
         if answer == 0 or answer == 1:
-            return {"question": question, "context": context_string, "short": [], "long": [], "category": "no" if answer == 0 else "yes"}
+            return {
+                "question": question,
+                "context": context_string,
+                "short": [],
+                "long": [],
+                "category": "no" if answer == 0 else "yes",
+            }
 
     short_targets = []
     for s in short_answers:
-        short_targets.extend(s['text'])
+        short_targets.extend(s["text"])
     short_targets = list(set(short_targets))
 
     long_targets = []
     for s in long_answers:
-        if s['start_token'] == -1:
+        if s["start_token"] == -1:
             continue
-        answer = context[s['start_token']: s['end_token']]
-        html = is_html[s['start_token']: s['end_token']]
+        answer = context[s["start_token"] : s["end_token"]]
+        html = is_html[s["start_token"] : s["end_token"]]
         new_answer = " ".join([answer[i] for i in range(len(answer)) if not html[i]])
         if new_answer not in long_targets:
             long_targets.append(new_answer)
 
     category = "long_short" if len(short_targets + long_targets) > 0 else "null"
 
-    return {"question": question, "context": context_string, "short": short_targets, "long": long_targets, "category": category}
+    return {
+        "question": question,
+        "context": context_string,
+        "short": short_targets,
+        "long": long_targets,
+        "category": category,
+    }
 
+
 def main():
     dataset = load_from_disk("natural-questions-validation")
     dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"])
     print(dataset)
 
-    short_validation_dataset = dataset.filter(lambda x: (len(x['question']) + len(x['context'])) < 4 * 4096)
-    short_validation_dataset = short_validation_dataset.filter(lambda x: x['category'] != "null")
+    short_validation_dataset = dataset.filter(lambda x: (len(x["question"]) + len(x["context"])) < 4 * 4096)
+    short_validation_dataset = short_validation_dataset.filter(lambda x: x["category"] != "null")
     short_validation_dataset
 
     model_id = "vasudevgupta/flax-bigbird-natural-questions"
@@ -88,29 +107,38 @@ def forward(*args, **kwargs):
 
     def evaluate(example):
         # encode question and context so that they are separated by a tokenizer.sep_token and cut at max_length
-        inputs = tokenizer(example["question"], example["context"], return_tensors="jax", max_length=4096, padding="max_length", truncation=True)
+        inputs = tokenizer(
+            example["question"],
+            example["context"],
+            return_tensors="jax",
+            max_length=4096,
+            padding="max_length",
+            truncation=True,
+        )
 
         start_scores, end_scores, category = forward(**inputs)
 
         predicted_category = CATEGORY_MAPPING[category.item()]
 
-        example['targets'] = example['long'] + example['short']
-        if example['category'] in ['yes', 'no', 'null']:
-            example['targets'] = [example['category']]
-        example['has_tgt'] = example['category'] != 'null'
+        example["targets"] = example["long"] + example["short"]
+        if example["category"] in ["yes", "no", "null"]:
+            example["targets"] = [example["category"]]
+        example["has_tgt"] = example["category"] != "null"
         # Now target can be: "yes", "no", "null", "list of long & short answers"
 
-        if predicted_category in ['yes', 'no', 'null']:
-            example['output'] = [predicted_category]
-            example['match'] = example['output'] == example['targets']
-            example['has_pred'] = predicted_category != 'null'
+        if predicted_category in ["yes", "no", "null"]:
+            example["output"] = [predicted_category]
+            example["match"] = example["output"] == example["targets"]
+            example["has_pred"] = predicted_category != "null"
             return example
 
         max_size = 38 if predicted_category == "short" else 1024
-        start_score, end_score = get_best_valid_start_end_idx(start_scores[0], end_scores[0], top_k=8, max_size=max_size)
+        start_score, end_score = get_best_valid_start_end_idx(
+            start_scores[0], end_scores[0], top_k=8, max_size=max_size
+        )
 
         input_ids = inputs["input_ids"][0].tolist()
-        example["output"] = [tokenizer.decode(input_ids[start_score: end_score+1])]
+        example["output"] = [tokenizer.decode(input_ids[start_score : end_score + 1])]
 
         answers = expand_to_aliases(example["targets"], make_sub_answers=True)
         predictions = expand_to_aliases(example["output"])
@@ -122,15 +150,15 @@ def evaluate(example):
 
         # if there is a common element, it's an exact match
         example["match"] = len(list(answers & predictions)) > 0
-        example["has_pred"] = predicted_category != 'null' and len(predictions) > 0
+        example["has_pred"] = predicted_category != "null" and len(predictions) > 0
 
         return example
 
     short_validation_dataset = short_validation_dataset.map(evaluate)
 
     total = len(short_validation_dataset)
     matched = len(short_validation_dataset.filter(lambda x: x["match"] == 1))
-    print("EM score:", (matched / total)*100, "%")
+    print("EM score:", (matched / total) * 100, "%")
 
 
 if __name__ == "__main__":
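
The evaluation above is a lenient exact match: both targets and predictions are normalized and expanded by `expand_to_aliases` before taking a set intersection. A small sketch of that matching step, assuming it is run next to `evaluate.py` (the strings are made up):

```python
# Sketch, not part of the commit: how the alias-based EM check behaves.
from evaluate import expand_to_aliases

# make_sub_answers=True also admits the target minus its first or last word,
# so the one-word prediction "beatles" can match the target "The Beatles".
answers = expand_to_aliases(["The Beatles"], make_sub_answers=True)
predictions = expand_to_aliases(["beatles"])

is_match = len(answers & predictions) > 0  # True
```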
(Diffs for the remaining two changed files are not shown.)