
Commit: Start of merge
muellerzr committed Aug 2, 2022
1 parent 2afa61e commit ca65276
Showing 1 changed file with 92 additions and 0 deletions.
src/accelerate/test_utils/scripts/test_metrics.py: 92 additions & 0 deletions
@@ -58,8 +58,100 @@ def test_torch_metrics(accelerator: Accelerator, num_samples=82):
), f"Unexpected number of inputs:\n Expected: {num_samples}\n Actual: {len(inps)}"


import math

import torch
from torch.utils.data import DataLoader

import evaluate
from accelerate import Accelerator
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer


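# Builds the evaluation dataloader: tokenizes the GLUE MRPC validation split with a small
# internal-testing BERT checkpoint, pads each batch dynamically in collate_fn, and keeps
# shuffle=False so every configuration sees the samples in the same order.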
def get_dataloader(accelerator: Accelerator, drop_last=False):
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/mrpc-bert-base-cased")
    dataset = load_dataset("glue", "mrpc", split="validation")

    def tokenize_function(examples):
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    with accelerator.main_process_first():
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )

    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    return DataLoader(tokenized_datasets, shuffle=False, collate_fn=collate_fn, batch_size=16, drop_last=drop_last)


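# Builds two ways of running the same evaluation: "no" is the unprepared model and dataloader
# (a single-device baseline), "ddp" is the accelerator-prepared model and dataloader used for
# the distributed run.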
def get_setup(dispatch_batches, split_batches, drop_last):
    accelerator = Accelerator(dispatch_batches=dispatch_batches, split_batches=split_batches)
    dataloader = get_dataloader(accelerator, drop_last)
    model = AutoModelForSequenceClassification.from_pretrained(
        "hf-internal-testing/mrpc-bert-base-cased", return_dict=True
    )
    ddp_model, ddp_dataloader = accelerator.prepare(model, dataloader)
    return {"ddp": [ddp_model, ddp_dataloader, "cuda:0"], "no": [model, dataloader, accelerator.device]}, accelerator


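# Runs the MRPC evaluation twice: once on a single device as a baseline, and once through the
# prepared dataloader using accelerator.gather_for_metrics(), which gathers predictions and
# references from every process and drops the samples that were duplicated to make the last
# batch divide evenly. Any metric key where the two results disagree is printed.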
def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):
    drop_last = dispatch_batches
    metric = evaluate.load("glue", "mrpc")
    setup, accelerator = get_setup(dispatch_batches, split_batches, drop_last)
    # First do baseline
    if accelerator.is_local_main_process:
        print("Running baseline")
    model, dataloader, device = setup["no"]
    if accelerator.is_local_main_process:
        print(f"Len dl: {len(dataloader)}\nLen dset: {len(dataloader.dataset)}\n")
    model.to(device)
    model.eval()
    for batch in dataloader:
        batch.to(device)
        with torch.inference_mode():
            outputs = model(**batch)
        preds = outputs.logits.argmax(dim=-1)
        metric.add_batch(predictions=preds, references=batch["labels"])
    baseline = metric.compute()

    # Then do distributed
    if accelerator.is_local_main_process:
        print("Running with Gradient State")
    model, dataloader, device = setup["ddp"]
    model.eval()
    for batch in dataloader:
        with torch.inference_mode():
            outputs = model(**batch)
        preds = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        preds, references = accelerator.gather_for_metrics((preds, references))
        metric.add_batch(predictions=preds, references=references)
    distributed = metric.compute()

    for key in "accuracy f1".split():
        if not math.isclose(baseline[key], distributed[key]) and accelerator.is_local_main_process:
            print(
                f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n"
            )



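# Runs the gather_for_metrics comparison for every combination of split_batches and
# dispatch_batches, resetting the shared accelerator state between runs, then moves on to the
# torch-metrics checks defined earlier in the file.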
def main():
    accelerator = Accelerator(split_batches=False, dispatch_batches=False)
    if accelerator.is_local_main_process:
        print("**Testing gather_for_metrics**")
    for split_batches in [True, False]:
        for dispatch_batches in [True, False]:
            if accelerator.is_local_main_process:
                print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`")
            test_mrpc(dispatch_batches, split_batches)
            accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test torch metrics**")
    for split_batches in [True, False]:
        # ... (rest of main() is collapsed in this diff view and unchanged)

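The comparison in test_mrpc only exercises the deduplication behind gather_for_metrics when more than one process is used; as a rough sketch (assuming a two-GPU machine and the file path shown above), such a script could be launched with:

accelerate launch --num_processes 2 src/accelerate/test_utils/scripts/test_metrics.py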