[T104292598] Refactor the "LRA" training code -> Pytorch Lightning #343

Merged · 10 commits · Jul 8, 2022
44 changes: 13 additions & 31 deletions xformers/benchmarks/LRA/batch_fetch_results.py
@@ -10,16 +10,6 @@
from pathlib import Path
from typing import Any, Dict

reference_steps = {
"image": 35176,
"listops": 10000,
"pathfinder32-curv_contour_length_14": 62400,
"pathfinder32-curv_baseline": 62400,
"pathfinder32-curv_contour_length_9": 62400,
"text": 20000,
"retrieval": 30000,
}

if __name__ == "__main__":
# Get the user requests
parser = argparse.ArgumentParser(
@@ -38,36 +28,28 @@

for attention in filter(lambda x: x.is_dir(), root.iterdir()):
logging.info(f"\nFound results for {attention.stem}")
task_logs = attention.glob("*/*.log")
task_jsons = attention.glob("*/test_eval_summary.json")
results[attention.stem] = {}

for task in filter(lambda x: "__0" in str(x), task_logs):
for task in task_jsons:
task_name = task.stem.split("__")[0]
logging.info(f"Logs found for task: {task_name}")
results[attention.stem][task_name] = -1
found_result = False

# - collect the individual results
with open(task, "r") as result_file:
for line in reversed(result_file.readlines()):
if '"component": "test"' in line:
found_result = True

# Check that all the steps are done
res = json.loads(line)

if res["train_step_idx"] == reference_steps[task_name]:
results[attention.stem][task_name] = res["best_accu"]
logging.info(
f"Final result found for {task_name}: {results[attention.stem][task_name]}"
)
else:
logging.info(
"Current step: {}/{}. Not finished".format(
res["train_step_idx"], reference_steps[task_name]
)
)
break
dct = json.load(result_file)
if "test_accu_mean" in dct:
found_result = True
results[attention.stem][task_name] = dct["test_accu_mean"]

logging.info(
f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
f"{results[attention.stem][task_name]}"
)
else:
break

# - report an error if no result was found
if not found_result:
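For reference, a minimal sketch (not part of this PR) of what the new loop consumes from each test_eval_summary.json: the only keys assumed here are the two that batch_fetch_results.py reads above; the values and any other fields are made up.

import json

# Hypothetical contents of one finished run's test_eval_summary.json;
# only "test_accu_mean" and "train_step_idx" are assumed, per the parsing above.
raw = '{"test_accu_mean": 0.59, "train_step_idx": 20000}'
dct = json.loads(raw)

if "test_accu_mean" in dct:
    # Mirrors the happy path of the new loop: report the mean test accuracy
    # and the step at which it was recorded.
    print(f"accuracy {dct['test_accu_mean']} at step {dct['train_step_idx']}")
else:
    # Mirrors the fallthrough: an unfinished run has no summary accuracy yet.
    print("no finished result in this summary")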
9 changes: 0 additions & 9 deletions xformers/benchmarks/LRA/batch_submit.py
@@ -37,14 +37,6 @@ def get_default_shared_folder() -> str:
parser.add_argument(
"--partition", default="a100", type=str, help="Partition where to submit"
)
parser.add_argument(
"-tb",
"--tb_path",
type=str,
help="Path to the tensorboard directory",
dest="tb_dir",
default=f"/{default_checkpoint_path}/{os.getenv('USER')}/xformers/tb",
)
args = parser.parse_args()

for attention in args.attentions:
@@ -54,5 +46,4 @@ def get_default_shared_folder() -> str:
+ f" --attention {attention} --task {task} --config {args.config_path}"
+ f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}"
+ f" --partition {args.partition}"
+ f" --tb_dir {args.tb_dir}/{attention}/{task}"
)
162 changes: 120 additions & 42 deletions xformers/benchmarks/LRA/code/model_wrapper.py
@@ -8,7 +8,9 @@
# https://github.com/mlpen/Nystromformer

from enum import Enum
from typing import Dict, Union

import pytorch_lightning as pl
import torch
import torch.nn as nn

@@ -17,6 +19,8 @@
from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
from xformers.utils import generate_matching_config

PLOutput = Dict[str, Union[float, torch.Tensor]]


class Pooling(str, Enum):
MEAN = "mean"
@@ -113,11 +117,12 @@ def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
return seq_score


class ModelTrunk(nn.Module):
class ModelTrunk(pl.LightningModule):
def __init__(self, config, model_name):
super().__init__()

config_model = config["model"]
self.config_training = config["training"]

self.enable_amp = config["training"]["mixed_precision"]
Contributor comment: Seems like this is no longer being used?

Contributor (author) reply: A little buried, but it's being used in configure_optimizers.

self.pooling_mode = Pooling(config_model["pooling_mode"])
@@ -134,6 +139,72 @@ def __init__(self, config, model_name):
* ff_config["hidden_layer_multiplier"]
)

def training_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
outputs = self(**batch)
self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) # type: ignore
self.log("train_accu", outputs["accu"], sync_dist=True)
return outputs

def training_epoch_end(self, outputs):
logs = self.eval_epoch_end(outputs)
self.log("train_accu_mean", logs["accu"], sync_dist=True)

def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.config_training["learning_rate"],
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=self.config_training["weight_decay"],
)

lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=self.config_training["learning_rate"],
pct_start=self.config_training["warmup"]
/ self.config_training["num_train_steps"],
anneal_strategy=self.config_training["lr_decay"],
total_steps=self.config_training["num_train_steps"],
)

return [optimizer], [lr_scheduler]

def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
outputs = self(**batch)
return outputs

def eval_epoch_end(self, outputs, prefix: str = "train"):
logs = {}
counts = torch.tensor([x["count"] for x in outputs]).float()
logs["count"] = counts.sum()
for k in ("accu", "loss"):
logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
"count"
]
self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
return logs

def validation_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
outputs = self.eval_step(batch, batch_idx)
self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) # type: ignore
self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
return outputs

def validation_epoch_end(self, outputs):
self.eval_epoch_end(outputs, prefix="val")

def test_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
return self.eval_step(batch, batch_idx)

def test_epoch_end(self, outputs):
self.eval_epoch_end(outputs, prefix="test")


class ModelForSC(ModelTrunk):
def __init__(self, config, model_name):
@@ -146,25 +217,26 @@ def __init__(self, config, model_name):
dim_mlp=self.dim_mlp,
)

def forward(self, input_ids_0, mask_0, label):
def forward( # type: ignore
self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
):

with torch.cuda.amp.autocast(enabled=self.enable_amp):
if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)

token_out = self.norm(
self.model(input_ids_0, encoder_input_mask=mask_0)
) * mask_0.unsqueeze(-1)
token_out = self.norm(
self.model(input_ids_0, encoder_input_mask=mask_0)
) * mask_0.unsqueeze(-1)

seq_scores = self.seq_classifer(token_out)
seq_scores = self.seq_classifer(token_out)

seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}
seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}

return outputs

@@ -180,31 +252,37 @@ def __init__(self, config, model_name):
dim_mlp=self.dim_mlp,
)

def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label):

with torch.cuda.amp.autocast(enabled=self.enable_amp):
mask_0, mask_1 = mask_0.long(), mask_1.long()

if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

# Concatenate the two inputs into one batch
input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
masks = torch.cat([mask_0, mask_1], dim=0)

tokens_out = self.norm(
self.model(input_ids, encoder_input_mask=masks)
) * masks.unsqueeze(-1)

seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))

seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}
def forward( # type: ignore
self,
input_ids_0: torch.Tensor,
input_ids_1: torch.Tensor,
mask_0: torch.Tensor,
mask_1: torch.Tensor,
label: torch.Tensor,
):

mask_0, mask_1 = mask_0.long(), mask_1.long()

if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

# Concatenate the two inputs into one batch
input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
masks = torch.cat([mask_0, mask_1], dim=0)

tokens_out = self.norm(
self.model(input_ids, encoder_input_mask=masks)
) * masks.unsqueeze(-1)

seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))

seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}

return outputs
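To make the new epoch-end bookkeeping concrete, here is a self-contained sketch (not part of the diff) of the count-weighted averaging that eval_epoch_end applies to the per-batch outputs; the batch sizes and metric values below are made up.

import torch

# Hypothetical per-batch outputs, shaped like the dicts returned by the
# training/validation/test steps above: a mean metric per batch plus the batch size.
outputs = [
    {"accu": 0.50, "loss": 1.20, "count": 8},
    {"accu": 0.75, "loss": 0.90, "count": 4},
]

counts = torch.tensor([x["count"] for x in outputs]).float()
total = counts.sum()  # 12 samples in this toy epoch

for k in ("accu", "loss"):
    # Weight each batch mean by its batch size so uneven batches do not skew the epoch mean.
    weighted_mean = (torch.tensor([x[k] for x in outputs]) * counts).sum() / total
    print(k, round(weighted_mean.item(), 4))  # accu -> 0.5833, loss -> 1.1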