Commit 65beac1: save progress
thevasudevgupta committed Jun 21, 2021 · 1 parent 613a224
Showing 3 changed files with 79 additions and 38 deletions.
@@ -0,0 +1,4 @@
datasets
flax
optax
jsonlines
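(The filename of this new file isn't shown in this view; assuming it is the project's pip requirements file, the dependencies install with:

    pip install -r requirements.txt

or directly: pip install datasets flax optax jsonlines)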
@@ -126,10 +126,10 @@ def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a question answering task")

parser.add_argument(
"--train_file", type=str, default=None, help="A json file containing the tokenized training data."
"--train_file", type=str, default=None, help="A jsonl file containing the tokenized training data."
)
parser.add_argument(
"--validation_file", type=str, default=None, help="A json file containing the tokenized validation data."
"--validation_file", type=str, default=None, help="A jsonl file containing the tokenized validation data."
)
parser.add_argument(
"--model_name_or_path",
@@ -138,39 +138,55 @@ def parse_args():
required=True,
)
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
"--per_device_train_batch_size",
type=int,
default=1,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_train_batch_size",
"--gradient_accumulation_steps",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_eval_batch_size",
type=int,
- default=8,
+ default=1,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--learning_rate",
"--lr1",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
+ parser.add_argument(
+ "--lr2",
+ type=float,
+ default=1e-4,
+ help="Learning rate used for the remaining 70% of training steps.",
+ )
+ parser.add_argument(
+ "--block_size",
+ type=int,
+ default=84,
+ help="Number of tokens in each block of BigBird's block-sparse attention.",
+ )
+ parser.add_argument(
+ "--num_random_blocks",
+ type=int,
+ default=3,
+ help="Number of random blocks in BigBird's block-sparse attention.",
+ )
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument("--num_train_epochs", type=int, default=5, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
args = parser.parse_args()
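With these flags, the effective global batch size works out to per_device_train_batch_size × local device count × gradient_accumulation_steps, assuming the accumulation flag is actually wired into the update step (that wiring isn't visible in this diff). A quick check with the defaults above:

    import jax

    # hypothetical numbers: defaults of 1 per device and 8 accumulation steps
    effective_batch_size = 1 * jax.local_device_count() * 8
    print(effective_batch_size)  # 64 on a v3-8 TPU (8 local devices)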
@@ -181,10 +197,10 @@ def parse_args():
else:
if args.train_file is not None:
extension = args.train_file.split(".")[-1]
assert extension in ["json"], "`train_file` should be a json file."
assert extension in ["jsonl", "json"], "`train_file` should be a json/jsonl file."
if args.validation_file is not None:
extension = args.validation_file.split(".")[-1]
assert extension in ["json"], "`validation_file` should be a json file."
assert extension in ["jsonl", "json"], "`validation_file` should be a json/jsonl file."

if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
@@ -241,16 +257,26 @@ def eval_data_collator(dataset: Dataset, batch_size: int, pad_id: int, max_length


def create_learning_rate_fn(
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+ train_ds_size: int,
+ train_batch_size: int,
+ num_train_epochs: int,
+ lr1: float,
+ lr2: float,
) -> Callable[[int], jnp.array]:
"""Returns a linear warmup, linear_decay learning rate function."""

steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
- decay_fn = optax.linear_schedule(
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps

+ # train with lr1 for the first 30% of steps, then with lr2 for the rest
+ transition_steps = int(num_train_steps * 0.3)

+ lr1_fn = optax.linear_schedule(init_value=lr1, end_value=lr1, transition_steps=transition_steps)
+ lr2_fn = optax.linear_schedule(
+ init_value=lr2, end_value=lr2, transition_steps=num_train_steps - transition_steps
)
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+ schedule_fn = optax.join_schedules(schedules=[lr1_fn, lr2_fn], boundaries=[transition_steps])
return schedule_fn
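
A quick sanity check of the resulting two-phase schedule (a sketch, not part of the commit; the sizes are made up):

    # 1000 examples / batch 10 / 5 epochs -> 500 steps, boundary at step 150
    lr_fn = create_learning_rate_fn(
        train_ds_size=1000, train_batch_size=10, num_train_epochs=5, lr1=5e-5, lr2=1e-4
    )
    print(float(lr_fn(0)), float(lr_fn(149)))    # ~5e-5 for the first 30% of steps
    print(float(lr_fn(150)), float(lr_fn(499)))  # ~1e-4 for the rest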


@@ -272,12 +298,16 @@ def main():
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()

- model = FlaxBigBirdForNaturalQuestions.from_pretrained(args.model_id)
- tokenizer = AutoTokenizer.from_pretrained(args.model_id)
+ model = FlaxBigBirdForNaturalQuestions.from_pretrained(
+ args.model_name_or_path,
+ block_size=args.block_size,
+ num_random_blocks=args.num_random_blocks,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
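
For reference, extra keyword arguments to from_pretrained are forwarded to the model config and override the checkpoint's values; the same mechanism in isolation (assuming the public google/bigbird-roberta-base checkpoint):

    from transformers import BigBirdConfig

    config = BigBirdConfig.from_pretrained(
        "google/bigbird-roberta-base", block_size=84, num_random_blocks=3
    )
    print(config.block_size, config.num_random_blocks)  # 84 3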

# load dataset from files created using `prepare_natural_questions.py` script
train_dataset = load_dataset("json", data_files=args.train_file)["train"]
eval_dataset = load_dataset("json", data_files=args.validation_file)["train"]
train_dataset = load_dataset("json", data_files=args.train_file, split="train")
eval_dataset = load_dataset("json", data_files=args.validation_file, split="train")

# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
@@ -307,7 +337,7 @@ def write_metric(train_metrics, eval_metrics, train_time, step):
eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

learning_rate_fn = create_learning_rate_fn(
- len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
+ len(train_dataset), train_batch_size, args.num_train_epochs, args.lr1, args.lr2
)

state = create_train_state(
@@ -365,18 +395,15 @@ def eval_step(state, batch):
loss = state.loss_fn(
start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels
)
- # TODO We should have following line in glue???
- # metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
- return loss
+ metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
+ return metrics

p_eval_step = jax.pmap(eval_step, axis_name="batch")

- # TODO: setup metric

logger.info("===== Starting training ({num_epochs} epochs) =====")
train_time = 0

- # make sure weights are replicated on each device
+ # make sure state (params + opt_state) is replicated on each device
state = replicate(state)

for epoch in range(1, num_epochs + 1):
@@ -389,19 +416,19 @@ def eval_step(state, batch):

# train
for batch in train_data_collator(input_rng, train_dataset, train_batch_size, tokenizer.pad_token_id, max_length=4096):
- # batch = self.data_collator(batch)
state, metrics, dropout_rng = p_train_step(state, batch, dropout_rng)
train_metrics.append(metrics)
train_time += time.time() - train_start
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")

logger.info(" Evaluating...")

+ losses = []

# evaluate
for batch in eval_data_collator(eval_dataset, eval_batch_size, tokenizer.pad_token_id, max_length=4096):
- # batch = self.data_collator(batch)
- metrics = p_eval_step(state, batch)
- # TODO: collect metric
+ metric = p_eval_step(state, batch)
+ losses.append(unreplicate(metric)["loss"])

# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(eval_dataset) % eval_batch_size
@@ -412,11 +439,11 @@
batch = eval_dataset[-num_leftover_samples:]
batch = {k: jnp.array(v) for k, v in batch.items()}

- metrics = eval_step(state, batch)
- # TODO: collect metric
+ # eval_step uses lax.pmean, so it has to run under pmap; give the batch and
+ # the replicated state a leading device axis of 1 and run on a single device
+ batch = {k: v[None, ...] for k, v in batch.items()}
+ metric = jax.pmap(eval_step, axis_name="batch")(jax.tree_map(lambda x: x[:1], state), batch)
+ losses.append(unreplicate(metric)["loss"])

eval_metric = "<do something>" # TODO
logger.info(f" Done! Eval metrics: {eval_metric}")
eval_metric = {"loss": losses}
logger.info(f" Done! Eval metrics: {unreplicate(eval_metric)}")

cur_step = epoch * (len(train_dataset) // train_batch_size)
write_metric(train_metrics, eval_metric, train_time, cur_step)
10 changes: 10 additions & 0 deletions examples/research_projects/question-answering/train.sh
@@ -0,0 +1,10 @@
python3 run_question_answering_flax.py \
--train_file= \
--validation_file= \
--model_name_or_path=google/bigbird-roberta-base \
--per_device_train_batch_size=1 \
--per_device_eval_batch_size=2 \
--gradient_accumulation_steps=8 \
--lr1=5.e-5 \
--lr2=1.e-4 \
--block_size=128
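
(The blank --train_file/--validation_file values are meant to point at the jsonl files produced by the prepare_natural_questions.py script referenced in the training code, e.g. --train_file=data/nq-train.jsonl with a path of your own.)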
