Skip to content

Commit

Permalink
[T104292598] Refactor the "LRA" training code -> Pytorch Lightning (#343
Browse files Browse the repository at this point in the history
)

* First attempt at PL trainer

* Blocksparse switch revisions (#342)

* minor cleanup; updated changelog

* fixed mypy error

* added checking for blocksparse availability

Co-authored-by: Chris Yuan <christopheryuan@learnfair1490.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@devfair0278.h2.fair>

* Finish PL refactor

* Fix coding style, remove unused imports

* Fix flake8 error

* Make isort happy

* Let pre-commit handle formatting...

* Add type hints, fix eval behavior

* Evaluate PL refactor with batch_submit.py

Co-authored-by: Chris Yuan <christopheryuan@learnfair1490.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@devfair0278.h2.fair>
  • Loading branch information
3 people authored Jul 8, 2022
1 parent a59281a commit 769cfe3
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 544 deletions.
44 changes: 13 additions & 31 deletions xformers/benchmarks/LRA/batch_fetch_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,6 @@
from pathlib import Path
from typing import Any, Dict

reference_steps = {
"image": 35176,
"listops": 10000,
"pathfinder32-curv_contour_length_14": 62400,
"pathfinder32-curv_baseline": 62400,
"pathfinder32-curv_contour_length_9": 62400,
"text": 20000,
"retrieval": 30000,
}

if __name__ == "__main__":
# Get the user requests
parser = argparse.ArgumentParser(
Expand All @@ -38,36 +28,28 @@

for attention in filter(lambda x: x.is_dir(), root.iterdir()):
logging.info(f"\nFound results for {attention.stem}")
task_logs = attention.glob("*/*.log")
task_jsons = attention.glob("*/test_eval_summary.json")
results[attention.stem] = {}

for task in filter(lambda x: "__0" in str(x), task_logs):
for task in task_jsons:
task_name = task.stem.split("__")[0]
logging.info(f"Logs found for task: {task_name}")
results[attention.stem][task_name] = -1
found_result = False

# - collect the individual results
with open(task, "r") as result_file:
for line in reversed(result_file.readlines()):
if '"component": "test"' in line:
found_result = True

# Check that all the steps are done
res = json.loads(line)

if res["train_step_idx"] == reference_steps[task_name]:
results[attention.stem][task_name] = res["best_accu"]
logging.info(
f"Final result found for {task_name}: {results[attention.stem][task_name]}"
)
else:
logging.info(
"Current step: {}/{}. Not finished".format(
res["train_step_idx"], reference_steps[task_name]
)
)
break
dct = json.load(result_file)
if "test_accu_mean" in dct:
found_result = True
results[attention.stem][task_name] = dct["test_accu_mean"]

logging.info(
f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
f"{results[attention.stem][task_name]}"
)
else:
break

# - report an error if no result was found
if not found_result:
Expand Down
9 changes: 0 additions & 9 deletions xformers/benchmarks/LRA/batch_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,6 @@ def get_default_shared_folder() -> str:
parser.add_argument(
"--partition", default="a100", type=str, help="Partition where to submit"
)
parser.add_argument(
"-tb",
"--tb_path",
type=str,
help="Path to the tensorboard directory",
dest="tb_dir",
default=f"/{default_checkpoint_path}/{os.getenv('USER')}/xformers/tb",
)
args = parser.parse_args()

for attention in args.attentions:
Expand All @@ -54,5 +46,4 @@ def get_default_shared_folder() -> str:
+ f" --attention {attention} --task {task} --config {args.config_path}"
+ f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}"
+ f" --partition {args.partition}"
+ f" --tb_dir {args.tb_dir}/{attention}/{task}"
)
162 changes: 120 additions & 42 deletions xformers/benchmarks/LRA/code/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
# https://github.com/mlpen/Nystromformer

from enum import Enum
from typing import Dict, Union

import pytorch_lightning as pl
import torch
import torch.nn as nn

Expand All @@ -17,6 +19,8 @@
from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
from xformers.utils import generate_matching_config

PLOutput = Dict[str, Union[float, torch.Tensor]]


class Pooling(str, Enum):
MEAN = "mean"
Expand Down Expand Up @@ -113,11 +117,12 @@ def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
return seq_score


class ModelTrunk(nn.Module):
class ModelTrunk(pl.LightningModule):
def __init__(self, config, model_name):
super().__init__()

config_model = config["model"]
self.config_training = config["training"]

self.enable_amp = config["training"]["mixed_precision"]
self.pooling_mode = Pooling(config_model["pooling_mode"])
Expand All @@ -134,6 +139,72 @@ def __init__(self, config, model_name):
* ff_config["hidden_layer_multiplier"]
)

def training_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
outputs = self(**batch)
self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) # type: ignore
self.log("train_accu", outputs["accu"], sync_dist=True)
return outputs

def training_epoch_end(self, outputs):
logs = self.eval_epoch_end(outputs)
self.log("train_accu_mean", logs["accu"], sync_dist=True)

def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.config_training["learning_rate"],
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=self.config_training["weight_decay"],
)

lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=self.config_training["learning_rate"],
pct_start=self.config_training["warmup"]
/ self.config_training["num_train_steps"],
anneal_strategy=self.config_training["lr_decay"],
total_steps=self.config_training["num_train_steps"],
)

return [optimizer], [lr_scheduler]

def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
outputs = self(**batch)
return outputs

def eval_epoch_end(self, outputs, prefix: str = "train"):
logs = {}
counts = torch.tensor([x["count"] for x in outputs]).float()
logs["count"] = counts.sum()
for k in ("accu", "loss"):
logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
"count"
]
self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
return logs

def validation_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
outputs = self.eval_step(batch, batch_idx)
self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) # type: ignore
self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
return outputs

def validation_epoch_end(self, outputs):
self.eval_epoch_end(outputs, prefix="val")

def test_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
return self.eval_step(batch, batch_idx)

def test_epoch_end(self, outputs):
self.eval_epoch_end(outputs, prefix="test")


class ModelForSC(ModelTrunk):
def __init__(self, config, model_name):
Expand All @@ -146,25 +217,26 @@ def __init__(self, config, model_name):
dim_mlp=self.dim_mlp,
)

def forward(self, input_ids_0, mask_0, label):
def forward( # type: ignore
self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
):

with torch.cuda.amp.autocast(enabled=self.enable_amp):
if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)

token_out = self.norm(
self.model(input_ids_0, encoder_input_mask=mask_0)
) * mask_0.unsqueeze(-1)
token_out = self.norm(
self.model(input_ids_0, encoder_input_mask=mask_0)
) * mask_0.unsqueeze(-1)

seq_scores = self.seq_classifer(token_out)
seq_scores = self.seq_classifer(token_out)

seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}
seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}

return outputs

Expand All @@ -180,31 +252,37 @@ def __init__(self, config, model_name):
dim_mlp=self.dim_mlp,
)

def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label):

with torch.cuda.amp.autocast(enabled=self.enable_amp):
mask_0, mask_1 = mask_0.long(), mask_1.long()

if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

# Concatenate the two inputs into one batch
input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
masks = torch.cat([mask_0, mask_1], dim=0)

tokens_out = self.norm(
self.model(input_ids, encoder_input_mask=masks)
) * masks.unsqueeze(-1)

seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))

seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}
def forward( # type: ignore
self,
input_ids_0: torch.Tensor,
input_ids_1: torch.Tensor,
mask_0: torch.Tensor,
mask_1: torch.Tensor,
label: torch.Tensor,
):

mask_0, mask_1 = mask_0.long(), mask_1.long()

if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

# Concatenate the two inputs into one batch
input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
masks = torch.cat([mask_0, mask_1], dim=0)

tokens_out = self.norm(
self.model(input_ids, encoder_input_mask=masks)
) * masks.unsqueeze(-1)

seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))

seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}

return outputs
Loading

0 comments on commit 769cfe3

Please sign in to comment.