From 328d54c6bd7813346196a3d00d4231e62098adf7 Mon Sep 17 00:00:00 2001 From: lisjin Date: Mon, 27 Jun 2022 16:05:31 -0700 Subject: [PATCH 1/9] First attempt at PL trainer --- xformers/benchmarks/LRA/code/model_wrapper.py | 124 +++++--- xformers/benchmarks/LRA/run_tasks_pl.py | 273 ++++++++++++++++++ xformers/components/multi_head_dispatch.py | 3 + 3 files changed, 366 insertions(+), 34 deletions(-) create mode 100644 xformers/benchmarks/LRA/run_tasks_pl.py diff --git a/xformers/benchmarks/LRA/code/model_wrapper.py b/xformers/benchmarks/LRA/code/model_wrapper.py index cf0be64e18..417e997991 100755 --- a/xformers/benchmarks/LRA/code/model_wrapper.py +++ b/xformers/benchmarks/LRA/code/model_wrapper.py @@ -9,6 +9,7 @@ from enum import Enum +import pytorch_lightning as pl import torch import torch.nn as nn @@ -113,11 +114,12 @@ def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor): return seq_score -class ModelTrunk(nn.Module): +class ModelTrunk(pl.LightningModule): def __init__(self, config, model_name): super().__init__() config_model = config["model"] + self.config_training=config["training"] self.enable_amp = config["training"]["mixed_precision"] self.pooling_mode = Pooling(config_model["pooling_mode"]) @@ -134,6 +136,62 @@ def __init__(self, config, model_name): * ff_config["hidden_layer_multiplier"] ) + def training_step(self, batch, batch_idx): + outputs = self(**batch) + return outputs + + def training_epoch_end(self, outputs): + logs = self.eval_epoch_end(outputs) + return {f"train_{k}": v for k, v in logs.items()} + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.config_training["learning_rate"], + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=self.config_training["weight_decay"], + ) + + lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer=optimizer, + max_lr=self.config_training["learning_rate"], + pct_start=self.config_training["warmup"] / self.config_training["num_train_steps"], + anneal_strategy=self.config_training["lr_decay"], + total_steps=self.config_training["num_train_steps"], + ) + + return [optimizer], [lr_scheduler] + + def eval_step(self, batch, batch_idx): + outputs = self(**batch) + return outputs + + def eval_epoch_end(self, outputs): + logs = {} + counts = torch.tensor([x["count"] for x in outputs]).float() + logs["count"] = counts.sum() + for key in ("accu", "loss"): + logs[key] = (torch.tensor([x[key] for x in outputs]) * counts).sum() / logs["count"] + self.logger.log_metrics(logs) + return logs + + def validation_step(self, batch, batch_idx): + outputs = self.eval_step(batch, batch_idx) + self.log("val_accu", outputs["accu"], sync_dist=True) + return outputs + + def validation_epoch_end(self, outputs): + logs = self.eval_epoch_end(outputs) + return {f"val_{k}": v for k, v in logs.items()} + + def test_step(self, batch, batch_idx): + return self.eval_step(batch, batch_idx) + + def test_epoch_end(self, outputs): + logs = self.eval_epoch_end(outputs) + return {f"test_{k}": v for k, v in logs.items()} + class ModelForSC(ModelTrunk): def __init__(self, config, model_name): @@ -148,23 +206,22 @@ def __init__(self, config, model_name): def forward(self, input_ids_0, mask_0, label): - with torch.cuda.amp.autocast(enabled=self.enable_amp): - if self.pooling_mode == Pooling.CLS: - input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) + if self.pooling_mode == Pooling.CLS: + input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) - token_out = self.norm( - self.model(input_ids_0, encoder_input_mask=mask_0) - ) * mask_0.unsqueeze(-1) + token_out = self.norm( + self.model(input_ids_0, encoder_input_mask=mask_0) + ) * mask_0.unsqueeze(-1) - seq_scores = self.seq_classifer(token_out) + seq_scores = self.seq_classifer(token_out) - seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) - seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) - outputs = { - "loss": seq_loss.mean(), - "accu": seq_accu.mean(), - "count": label.size(0), - } + seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) + seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) + outputs = { + "loss": seq_loss.mean(), + "accu": seq_accu.mean(), + "count": label.size(0), + } return outputs @@ -182,29 +239,28 @@ def __init__(self, config, model_name): def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label): - with torch.cuda.amp.autocast(enabled=self.enable_amp): - mask_0, mask_1 = mask_0.long(), mask_1.long() + mask_0, mask_1 = mask_0.long(), mask_1.long() - if self.pooling_mode == Pooling.CLS: - input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) - input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size) + if self.pooling_mode == Pooling.CLS: + input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) + input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size) - # Concatenate the two inputs into one batch - input_ids = torch.cat([input_ids_0, input_ids_1], dim=0) - masks = torch.cat([mask_0, mask_1], dim=0) + # Concatenate the two inputs into one batch + input_ids = torch.cat([input_ids_0, input_ids_1], dim=0) + masks = torch.cat([mask_0, mask_1], dim=0) - tokens_out = self.norm( - self.model(input_ids, encoder_input_mask=masks) - ) * masks.unsqueeze(-1) + tokens_out = self.norm( + self.model(input_ids, encoder_input_mask=masks) + ) * masks.unsqueeze(-1) - seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0)) + seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0)) - seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) - seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) - outputs = { - "loss": seq_loss.mean(), - "accu": seq_accu.mean(), - "count": label.size(0), - } + seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) + seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) + outputs = { + "loss": seq_loss.mean(), + "accu": seq_accu.mean(), + "count": label.size(0), + } return outputs diff --git a/xformers/benchmarks/LRA/run_tasks_pl.py b/xformers/benchmarks/LRA/run_tasks_pl.py new file mode 100644 index 0000000000..24bea7e3f4 --- /dev/null +++ b/xformers/benchmarks/LRA/run_tasks_pl.py @@ -0,0 +1,273 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# CREDITS: adapted from the Nystromformer repo +# https://github.com/mlpen/Nystromformer + +import argparse +import datetime +import json +import logging +import math +import os +import random +import sys +import time +from contextlib import suppress +from enum import Enum +from pathlib import Path +from typing import Any, Dict + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.distributed as dist +import torch.nn as nn +from fvcore.nn import FlopCountAnalysis, flop_count_str +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.tensorboard import SummaryWriter + +from xformers.benchmarks.LRA.code.dataset import LRADataset +from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual +from xformers.benchmarks.utils import temp_files_ctx +from xformers.components.attention import ATTENTION_REGISTRY + + +class Task(str, Enum): + Retrieval = "retrieval" + ListOps = "listops" + Image = "image" + PathfinderBaseline = "pathfinder32-curv_baseline" + PathfinderContour9 = "pathfinder32-curv_contour_length_9" + PathfinderContour14 = "pathfinder32-curv_contour_length_14" + Text = "text" + + +def load_config(path: str) -> Dict: + with open(Path(path).absolute(), "r") as fileio: + config = json.load(fileio) + + # Duplicate the pathfinder configs + config["pathfinder32-curv_baseline"] = config["pathfinder32"] + config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"] + config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"] + return config + + +def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: + task = args.task + attention_name = args.attention + + if task == Task.Retrieval: + model: nn.Module = ModelForSCDual(config[f"{task}"], attention_name) + else: + model = ModelForSC(config[f"{task}"], attention_name) + + logging.info(model) + logging.info( + f"num_parameter: {np.sum([np.prod(weight.size()) for weight in model.parameters()]) // 1e3 / 1e3}M" + ) + + with torch.no_grad(): + # Check the flops + seq_len = config[f"{task}"]["model"]["common"]["seq_len"] + x = torch.rand(1, seq_len).long() + mask = torch.rand(1, seq_len).long() + indices = torch.rand(1, seq_len).long() + flops = FlopCountAnalysis(model.model, (x, mask, indices)) + logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops") + logging.info(flop_count_str(flops)) + + return model + + +def get_arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--attention", + type=str, + help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. \ + A list can be passed to test several mechanisms in sequence", + dest="attention", + required=True, + ) + parser.add_argument( + "--task", + type=Task, + help=f"Task to chose, among {[t.value for t in Task]}.", + dest="task", + required=True, + ) + parser.add_argument( + "--skip_train", + type=bool, + help="Whether to skip training, and test an existing model", + dest="skip_train", + default=False, + ) + parser.add_argument( + "--config", + type=str, + help="Path to the config being used", + dest="config", + default="./config.json", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the checkpoint directory", + dest="checkpoint_dir", + default=f"/checkpoints/{os.getenv('USER')}/xformers", + ) + parser.add_argument( + "--debug", + help="Make it easier to debug a possible issue", + dest="debug", + default=False, + action="store_true", + ) + parser.add_argument( + "--world_size", + help="Number of GPUs used", + dest="world_size", + type=int, + default=1, + ) + parser.add_argument( + "--sweep_parameters", + help="Rewrite some hyperparameters in the config", + dest="sweep_parameters", + type=dict, + default=None, + ) + return parser + + +def setup_log(args, rank, attention_name, task): + logger = TensorBoardLogger( + save_dir=args.checkpoint_dir, + version=f"{task}__{attention_name}__{rank}" + ) + log_dir = os.path.join(logger._save_dir, logger._version) + return log_dir, logger + + +def rewrite_hyper(config, rewrites): + def replace(config_dict, k, v): + if len(k.split(":")) == 1: + config_dict[k] = v + return + first_key = k.split(":")[0] + assert first_key in config_dict, first_key + k = k[len(first_key) + 1 :] + replace(config_dict[first_key], k, v) + + for k, v in rewrites.items(): + replace(config, k, v) + return config + + +def build_dataloaders(args: argparse.Namespace, config_training: Dict): + datasets = {} + for component in ("train", "dev", "test"): + datasets[component] = LRADataset( + file_path=f"datasets/{args.task}.{component}.pickle", + seq_len=config_training["seq_len"], + ) + + # Gradient accumulation + accumu_steps = config_training["gradient_accumulation"] + logging.info(f"accumu_steps={accumu_steps}") + + # Batch size + per_gpu_batch_size = ( + config_training["batch_size"] // args.world_size // accumu_steps + ) + logging.warning( + f"Requested batch size: {config_training['batch_size']}. Given world size and grad accumulation, per-gpu batch is {per_gpu_batch_size}" + ) + + # Training epochs + if accumu_steps > 1: + config_training["num_train_steps"] *= accumu_steps + config_training["num_eval_steps"] *= accumu_steps + args.epochs = math.ceil( + config_training["num_train_steps"] + * config_training["batch_size"] + / len(datasets["train"]) + ) + logging.warning( + "Requested train steps: {config_training['num_train_steps']}. Given dataset, this translates into {args.epochs} epochs." + ) + + dataloaders = { + k: DataLoader( + v, + batch_size=per_gpu_batch_size, + shuffle=False, + pin_memory=True, + ) + for k, v in datasets.items() + } + return dataloaders + + +def benchmark(args): + log_dir, logger = setup_log(args, "main", f"{args.attention}", f"{args.task}") + args.logger = logger + + config = load_config(args.config) + + config_task = config[f"{args.task}"] + if args.sweep_parameters is not None: + logging.info("Replacing hyperparameters") + rewrite_hyper(config_task, args.sweep_parameters) + + config_training = config_task["training"] + config_training["seq_len"] = config_task["model"]["common"]["seq_len"] + logging.info(f"Learning rate: {config_training['learning_rate']}") + + pl.seed_everything(config_training.get("seed", 0)) + dataloaders = build_dataloaders(args, config_training) + + model = build_model(args, config) + + checkpoint_callback = ModelCheckpoint( + monitor="val_accu", + dirpath=args.checkpoint_dir, + filename=logger._version, + every_n_train_steps=config_training["eval_frequency"], + ) + ckpt_path = os.path.join(args.checkpoint_dir, f"{logger._version}.ckpt") + + trainer = pl.Trainer( + #accelerator="ddp", + accumulate_grad_batches=config_training["gradient_accumulation"], + callbacks=[checkpoint_callback], + gpus=args.world_size, + max_steps=config_training["num_train_steps"], + precision=16 if config_training["mixed_precision"] else 32, + ) + + if not args.skip_train: + trainer.fit( + model, + train_dataloaders=dataloaders["train"], + val_dataloaders=dataloaders["dev"], + ) + trainer.validate( + model, + dataloaders=dataloaders["test"], + ckpt_path=ckpt_path, + ) + + +if __name__ == "__main__": + parser = get_arg_parser() + args = parser.parse_args() + benchmark(args) diff --git a/xformers/components/multi_head_dispatch.py b/xformers/components/multi_head_dispatch.py index 57413f3f92..f3eaffc915 100644 --- a/xformers/components/multi_head_dispatch.py +++ b/xformers/components/multi_head_dispatch.py @@ -31,6 +31,9 @@ class MultiHeadDispatchConfig: use_rotary_embeddings: Optional[bool] out_proj: Optional[nn.Module] + def __getitem__(self, item): + return getattr(self, item) + # Move head forward and fold into batch dim. dimensions become (B * nh, S, hs) def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int): From 9d1dc808ab4a1515b8fddf52f075dac4a8b727af Mon Sep 17 00:00:00 2001 From: lisjin Date: Mon, 27 Jun 2022 13:21:34 -0400 Subject: [PATCH 2/9] Blocksparse switch revisions (#342) * minor cleanup; updated changelog * fixed mypy error * added checking for blocksparse availability Co-authored-by: Chris Yuan Co-authored-by: Chris Yuan --- BENCHMARKS.md | 2 +- CHANGELOG.md | 1 + tests/test_core_attention.py | 21 ++++++- .../benchmark_causal_blocksparse.py | 2 + xformers/benchmarks/utils.py | 6 +- xformers/components/attention/core.py | 59 +++++++++++-------- xformers/components/attention/utils.py | 10 ---- 7 files changed, 62 insertions(+), 39 deletions(-) diff --git a/BENCHMARKS.md b/BENCHMARKS.md index ef04add59d..65101d8459 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -129,7 +129,7 @@ __Some results:__ _Note_: The estimated flops currently miss accounting for many operators, and are almost certainly an undercount. See issue [#154](https://github.com/fairinternal/xformers/issues/154) -## Casual Attention Blocksparse Optimization +## Causal Attention Blocksparse Optimization FP16 | FP32 :-------------------------:|:-------------------------: diff --git a/CHANGELOG.md b/CHANGELOG.md index df328734bd..fe2c87b01d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support several initialization options [#312] - Conv2DFeedforward feedforward part [#321] - VisualAttention [#329] +- Automatic blocksparse for causal attention [#334] ## [0.0.11] - 2022-05-30 diff --git a/tests/test_core_attention.py b/tests/test_core_attention.py index b1cc306fed..26c71bb05e 100644 --- a/tests/test_core_attention.py +++ b/tests/test_core_attention.py @@ -7,10 +7,18 @@ import torch from torch import nn +from xformers import _is_triton_available from xformers.components.attention._sputnik_sparse import SparseCS from xformers.components.attention.attention_mask import AttentionMask from xformers.components.attention.core import scaled_dot_product_attention +if _is_triton_available: + from xformers.triton.utils import gpu_capabilities_older_than_70 + +_is_blocksparse_available = ( + _is_triton_available and not gpu_capabilities_older_than_70() +) + _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] @@ -112,6 +120,9 @@ def test_amp_attention_sparsecs(device): assert r.dtype == expected_device +@pytest.mark.skipif( + not _is_blocksparse_available, reason="Blocksparse is not available" +) @pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("data_type", [torch.float16, torch.float32]) def test_switch_blocksparse(device, data_type): @@ -138,9 +149,14 @@ def test_switch_blocksparse(device, data_type): assert r_sparse.dtype == expected_device if r_custom.dtype == r_att_mask.dtype: - assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-3) + assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-2) + else: # r_custom fp16, r_att_mask fp32 + assert torch.allclose(r_custom, r_att_mask.half(), atol=1e-6, rtol=1e-2) +@pytest.mark.skipif( + not _is_blocksparse_available, reason="Blocksparse is not available" +) @pytest.mark.parametrize("device", ["cuda"]) def test_switch_blocksparse_dims(device): b, s, d, nh = 8, 128, 32, 8 @@ -159,6 +175,9 @@ def test_switch_blocksparse_dims(device): assert r.dtype == expected_device +@pytest.mark.skipif( + not _is_blocksparse_available, reason="Blocksparse is not available" +) @pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("training", [True, False]) @pytest.mark.parametrize("drop_prob", [0.0, 0.3]) diff --git a/xformers/benchmarks/benchmark_causal_blocksparse.py b/xformers/benchmarks/benchmark_causal_blocksparse.py index 05d630ba91..70a5e499b9 100644 --- a/xformers/benchmarks/benchmark_causal_blocksparse.py +++ b/xformers/benchmarks/benchmark_causal_blocksparse.py @@ -122,12 +122,14 @@ def sdp_attention(): title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}", units="runtime in ms", dash_key="torch", + legend_loc="upper left", ) pretty_plot( results_mem, title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}", units="peak memory usage in MB", dash_key="torch", + legend_loc="upper left", ) diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py index 8b5838fd44..37c3be02b2 100644 --- a/xformers/benchmarks/utils.py +++ b/xformers/benchmarks/utils.py @@ -55,7 +55,9 @@ def pretty_print(results, title, units): print("") -def pretty_plot(results, title, units: str, filename=None, dash_key=""): +def pretty_plot( + results, title, units: str, filename=None, dash_key="", legend_loc="bottom_right" +): """Graph out the contents of a dict. Dash key means that if the result label has this key, then it will be displayed with a dash""" @@ -86,7 +88,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""): plt.plot(list(results.keys()), v) plt.title(title) - plt.legend(list(workloads.keys()), loc="lower right") + plt.legend(list(workloads.keys()), loc=legend_loc) plt.ylabel(units) plt.xticks(rotation=45) diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py index 0dd457466c..9c8e47404b 100644 --- a/xformers/components/attention/core.py +++ b/xformers/components/attention/core.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. +import logging import math from contextlib import nullcontext from functools import lru_cache @@ -13,14 +14,20 @@ from xformers import _is_sparse_available, _is_triton_available from xformers.components.attention.attention_mask import AttentionMask -from xformers.components.attention.blocksparse import BlockSparseAttention -from xformers.components.attention.utils import reshape_heads if _is_sparse_available: from ._sputnik_sparse import SparseCS if _is_triton_available: from xformers.triton.softmax import softmax as triton_softmax + from xformers.triton.utils import gpu_capabilities_older_than_70 + +_is_blocksparse_available = ( + _is_triton_available and not gpu_capabilities_older_than_70() +) + +if _is_blocksparse_available: + from xformers.components.attention.blocksparse import BlockSparseAttention def _create_random_sparsity(matrix, sparsity, divisible_by=4): @@ -215,17 +222,19 @@ def scaled_query_key_softmax( return att -# 128 is default maxsize -@lru_cache(maxsize=128) -def _retrieve_blocksparse( - num_heads: int, seq_len: int, block_size: int -) -> BlockSparseAttention: - # Checks if blocksparse object exists in cache +if _is_blocksparse_available: + # 128 is default maxsize + @lru_cache(maxsize=128) + def _retrieve_blocksparse( + num_heads: int, seq_len: int, block_size: int + ) -> BlockSparseAttention: + # Checks if blocksparse object exists in cache - blocks = seq_len // block_size - print("Made uncached blocksparse") - layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long) - return BlockSparseAttention(layout=layout_fill, block_size=block_size, causal=True) + blocks = seq_len // block_size + layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long) + return BlockSparseAttention( + layout=layout_fill, block_size=block_size, causal=True + ) def blocksparse_attention( @@ -266,7 +275,7 @@ def blocksparse_attention( # Reshape attention (B, nh, S, hs) back to (N, S, hs) if orig_dim == 3: - return reshape_heads(att, *att.size()) + return att.flatten(0, 1) return att @@ -276,31 +285,31 @@ def scaled_dot_product_attention( v: torch.Tensor, att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]], dropout: Optional[torch.nn.Module] = None, - block_size=128, + block_size: int = 128, ) -> torch.Tensor: autocast_disabled = ( _is_sparse_available and isinstance(att_mask, SparseCS) or (att_mask is not None and att_mask.is_sparse) ) + seq_len = q.shape[-2] - # Check if causal is required but mask is not sparse; if fp16 or under amp context + # switch if: + # causal is required but mask is not sparse + # fp16 or under amp context + # sequence length is divisible by block size + # same seq len for K and Q switch_to_blocksparse = ( - _is_triton_available + _is_blocksparse_available and (att_mask is not None and not att_mask.is_sparse) and (isinstance(att_mask, AttentionMask) and att_mask.is_causal) and (q.dtype == torch.float16 or torch.is_autocast_enabled()) - ) - - # Switch only if sequence length is divisible by block size - # Blocksparse requires the same dimensions for K and Q for now - seq_len = q.shape[-2] - if ( - switch_to_blocksparse and not seq_len % block_size and q.shape[-2] == k.shape[-2] - ): - # print("switching to blocksparse...") + ) + + if switch_to_blocksparse: + logging.info("Switching causal attention to Triton blocksparse...") return blocksparse_attention(q, k, v, dropout, block_size) with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): diff --git a/xformers/components/attention/utils.py b/xformers/components/attention/utils.py index 012188e1a2..d6bb06a1ac 100644 --- a/xformers/components/attention/utils.py +++ b/xformers/components/attention/utils.py @@ -106,13 +106,3 @@ def bool_mask_to_additive( mask_ = torch.zeros_like(mask, dtype=dtype) mask_[~mask] = float("-inf") return mask_ - - -# (B, S, D) to (B, S, nh, hs) -def split_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int): - return t.view(B, nH, S, Hs) - - -# (B, nh, S, hs) back to (N, S, hs) -def reshape_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int): - return t.view(B * nH, S, Hs) From d7b76b8c31d027a677f1713a94927ce615340cbe Mon Sep 17 00:00:00 2001 From: lisjin Date: Tue, 28 Jun 2022 09:22:53 -0700 Subject: [PATCH 3/9] Finish PL refactor --- xformers/benchmarks/LRA/code/model_wrapper.py | 15 +- xformers/benchmarks/LRA/run_tasks.py | 587 +++++------------- xformers/benchmarks/LRA/run_tasks_pl.py | 273 -------- 3 files changed, 156 insertions(+), 719 deletions(-) delete mode 100644 xformers/benchmarks/LRA/run_tasks_pl.py diff --git a/xformers/benchmarks/LRA/code/model_wrapper.py b/xformers/benchmarks/LRA/code/model_wrapper.py index 417e997991..46d0a953b7 100755 --- a/xformers/benchmarks/LRA/code/model_wrapper.py +++ b/xformers/benchmarks/LRA/code/model_wrapper.py @@ -138,11 +138,13 @@ def __init__(self, config, model_name): def training_step(self, batch, batch_idx): outputs = self(**batch) + self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) + self.log("train_accu", outputs["accu"], sync_dist=True) return outputs def training_epoch_end(self, outputs): logs = self.eval_epoch_end(outputs) - return {f"train_{k}": v for k, v in logs.items()} + self.log(f"train_accu_mean", logs["accu"], sync_dist=True) def configure_optimizers(self): optimizer = torch.optim.AdamW( @@ -167,30 +169,29 @@ def eval_step(self, batch, batch_idx): outputs = self(**batch) return outputs - def eval_epoch_end(self, outputs): + def eval_epoch_end(self, outputs, prefix="train"): logs = {} counts = torch.tensor([x["count"] for x in outputs]).float() logs["count"] = counts.sum() for key in ("accu", "loss"): logs[key] = (torch.tensor([x[key] for x in outputs]) * counts).sum() / logs["count"] - self.logger.log_metrics(logs) + self.log(f"{prefix}_accu_mean", logs["accu"], sync_dist=True) return logs def validation_step(self, batch, batch_idx): outputs = self.eval_step(batch, batch_idx) + self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) self.log("val_accu", outputs["accu"], sync_dist=True) return outputs def validation_epoch_end(self, outputs): - logs = self.eval_epoch_end(outputs) - return {f"val_{k}": v for k, v in logs.items()} + self.eval_epoch_end(outputs, prefix="val") def test_step(self, batch, batch_idx): return self.eval_step(batch, batch_idx) def test_epoch_end(self, outputs): - logs = self.eval_epoch_end(outputs) - return {f"test_{k}": v for k, v in logs.items()} + self.eval_epoch_end(outputs, prefix="test") class ModelForSC(ModelTrunk): diff --git a/xformers/benchmarks/LRA/run_tasks.py b/xformers/benchmarks/LRA/run_tasks.py index 96c52357b8..22e703687b 100644 --- a/xformers/benchmarks/LRA/run_tasks.py +++ b/xformers/benchmarks/LRA/run_tasks.py @@ -18,14 +18,18 @@ import time from contextlib import suppress from enum import Enum +from glob import glob from pathlib import Path from typing import Any, Dict import numpy as np +import pytorch_lightning as pl import torch import torch.distributed as dist import torch.nn as nn from fvcore.nn import FlopCountAnalysis, flop_count_str +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger from torch.utils.data import DataLoader, DistributedSampler from torch.utils.tensorboard import SummaryWriter @@ -65,8 +69,8 @@ def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: else: model = ModelForSC(config[f"{task}"], attention_name) - args.logger.info(model) - args.logger.info( + logging.info(model) + logging.info( f"num_parameter: {np.sum([np.prod(weight.size()) for weight in model.parameters()]) // 1e3 / 1e3}M" ) @@ -77,433 +81,12 @@ def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: mask = torch.rand(1, seq_len).long() indices = torch.rand(1, seq_len).long() flops = FlopCountAnalysis(model.model, (x, mask, indices)) - args.logger.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops") - args.logger.info(flop_count_str(flops)) + logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops") + logging.info(flop_count_str(flops)) return model -def build_training_setup( - config_training: Dict, - task: Task, - model: nn.Module, - rank: int = 0, - world_size: int = 1, -): - datasets = {} - samplers = {} - - for component in ["train", "test", "dev"]: - dataset = LRADataset( - file_path=f"datasets/{task}.{component}.pickle", - seq_len=config_training["seq_len"], - ) - - sampler = DistributedSampler( - dataset, - num_replicas=world_size, - rank=rank, - shuffle=(component == "train"), - drop_last=(component == "train"), - ) # type:ignore - datasets[component] = dataset - samplers[component] = sampler - - logging.info(f"Learning rate: {config_training['learning_rate']}") - - optimizer = torch.optim.AdamW( - model.parameters(), - lr=config_training["learning_rate"], - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=config_training["weight_decay"], - ) - - lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( # type: ignore - optimizer=optimizer, - max_lr=config_training["learning_rate"], - pct_start=config_training["warmup"] / config_training["num_train_steps"], - anneal_strategy=config_training["lr_decay"], - total_steps=config_training["num_train_steps"], - ) - - amp_scaler = torch.cuda.amp.GradScaler(enabled=config_training["mixed_precision"]) - - logging.info(f"Dataloader ready. Rank {rank} of {world_size}") - - return datasets, samplers, optimizer, lr_scheduler, amp_scaler - - -def print_summary( - summary, - save_if_improved, - train_step_idx, - model, - checkpoint_path, - logger, - tb_logger=None, -): - - summary["loss"] = np.average(summary["loss"], weights=summary["count"]) - summary["accu"] = np.average(summary["accu"], weights=summary["count"]) - summary["count"] = np.sum(summary["count"]).astype(float) - - if summary["accu"] > summary["best_accu"]: - summary["best_accu"] = summary["accu"] - if save_if_improved: - best_accu = summary["best_accu"] - torch.save( - {"model_state_dict": model.state_dict()}, - checkpoint_path, - ) - logger.info(f"best_accu={best_accu:.3f}. Saved best model") - - summary["max_memory_mb"] = torch.cuda.max_memory_allocated() // 1e3 / 1e3 - - summary_round = {"train_step_idx": train_step_idx} - for key in summary: - if type(summary[key]) is str: - summary_round[key] = summary[key] - else: - summary_round[key] = round(summary[key], 4) - - if tb_logger: - tb_logger.add_scalar("acc", summary["accu"], train_step_idx) - tb_logger.add_scalar("loss", summary["loss"], train_step_idx) - tb_logger.add_scalar("max_mem", summary["max_memory_mb"], train_step_idx) - tb_logger.add_scalar("count", summary["count"], train_step_idx) - - logger.info(summary_round) - logger.info(json.dumps(summary_round, sort_keys=True) + "\n") - - summary["t"] = 0 - summary["loss"] = [] - summary["accu"] = [] - summary["count"] = [] - - -def setup_log(args, rank, attention_name, task): - log_f = Path( - os.path.join( - args.checkpoint_dir, f"{task}__{attention_name}__{rank}_output.log" - ) - ) - if not log_f.exists(): - log_f.parent.mkdir(parents=True, exist_ok=True) - with open(log_f, "x") as _: - pass - - logger = torch.multiprocessing.get_logger() - logger.setLevel(level=logging.INFO) - logger.addHandler(logging.FileHandler(filename=str(log_f))) - if rank == 0: - logger.addHandler(logging.StreamHandler(sys.stdout)) - return log_f.absolute(), logger - - -def eval_model(model, dataloaders, component, config, step): - model.eval() - - for dev_step_idx, batch_dev in enumerate(dataloaders[component]): - _ = step( - batch_dev, - component, - step_idx=dev_step_idx, - step_max=config["num_eval_steps"], - ) - - if dev_step_idx == config["num_eval_steps"]: - break - - model.train() - - -def rewrite_hyper(config, rewrites): - def replace(config_dict, k, v): - if len(k.split(":")) == 1: - config_dict[k] = v - return - first_key = k.split(":")[0] - assert first_key in config_dict, first_key - k = k[len(first_key) + 1 :] - replace(config_dict[first_key], k, v) - - for k, v in rewrites.items(): - replace(config, k, v) - return config - - -def seed_worker(_: int): - # Make sure that non-pytorch random generators are properly set - worker_seed = torch.initial_seed() % 2**32 - np.random.seed(worker_seed) - random.seed(worker_seed) - - -def benchmark(rank, args): - # Setup multiprocessing - dist.init_process_group( - init_method="file://" + args.temp_file, - backend="NCCL", - rank=rank, - world_size=args.world_size, - ) - try: - torch.cuda.set_device(args.gpu) - except AttributeError: - # Single node launcher - torch.cuda.set_device(rank) - - task = args.task - attention_name = args.attention - - # Build the problem - log_f_path, logger = setup_log(args, rank, attention_name, task) - args.logger = logger - config = load_config(args.config) - - config_task = config[f"{task}"] - if args.sweep_parameters is not None: - logger.info("Replacing hyperparameters") - rewrite_hyper(config_task, args.sweep_parameters) - - config_training = config_task["training"] - config_training["seq_len"] = config_task["model"]["common"]["seq_len"] - model = build_model(args, config) - - torch.manual_seed(config_training.get("seed", 0)) # also sets the cuda seed - np.random.seed(config_training.get("seed", 0)) - torch.backends.cudnn.enabled = True - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.cuda.reset_peak_memory_stats() - - # tensorboard - tb_logger = SummaryWriter(args.tb_dir) - - torch.manual_seed(config_training.get("seed", 0)) # also sets the cuda seed - np.random.seed(config_training.get("seed", 0)) - torch.backends.cudnn.enabled = True - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.cuda.reset_peak_memory_stats() - - # tensorboard - tb_logger = SummaryWriter(args.tb_dir) - - # Setup the training - device_ids = list(range(torch.cuda.device_count())) - logger.info(f"GPU list: {device_ids}") - model = model.cuda() - model = nn.parallel.DistributedDataParallel( - model, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True - ) - - ( - datasets, - samplers, - optimizer, - lr_scheduler, - amp_scaler, - ) = build_training_setup(config_training, task, model, rank, args.world_size) - - init_t = time.time() - - # Messenger structure which will be moved around to collect metrics - summary = { - comp: { - "t": 0, - "loss": [], - "accu": [], - "count": [], - "best_accu": 0, - "component": comp, - } - for comp in ["train", "dev", "test"] - } - - # Setup the dataloaders - accumu_steps = config_task["training"]["gradient_accumulation"] - per_gpu_batch_size = ( - config_training["batch_size"] // args.world_size // accumu_steps - ) - logging.warning( - "Requested batch size: {}. Given world size and grad accumulation, per-gpu batch is {}".format( - config_training["batch_size"], per_gpu_batch_size - ) - ) - - # reset train/eval steps if using gradient accumulation - if accumu_steps > 1: - config_training["num_train_steps"] *= accumu_steps - config_training["num_eval_steps"] *= accumu_steps - - epochs = math.ceil( - config_training["num_train_steps"] - * config_training["batch_size"] - / len(datasets["train"]) - ) - - logging.warning( - "Requested train steps: {}. Given dataset, this translates into {} epochs".format( - config_training["num_train_steps"], epochs - ) - ) - - logger.info(f"accumu_steps={accumu_steps}") - model_path = str(log_f_path).replace(".log", ".model") - g = torch.Generator() - g.manual_seed(config_training.get("seed", 0)) - - dataloaders = { - k: DataLoader( - datasets[k], - sampler=samplers[k], - batch_size=per_gpu_batch_size, - shuffle=False, - pin_memory=True, - num_workers=1, - worker_init_fn=seed_worker, - generator=g, - ) - for k in datasets.keys() - } - - # Our step function - def step( - batch: Dict[str, Any], - component: str, - step_idx: int, - step_max: int, - accumulate: bool = False, - ): - if step_idx > step_max: - logger.warning( - "Calling `step` beyond the training schedule, this is probably a mistake" - ) - return - - t0 = time.time() - batch_size = batch[list(batch.keys())[0]].size(0) - - for key in batch: - batch[key] = batch[key].cuda() - - if component == "train": - acc_context = model.no_sync() if accumulate else suppress() - - with acc_context, torch.autograd.set_detect_anomaly(args.debug): - outputs = model(**batch) - amp_scaler.scale(outputs["loss"]).backward() - - if not accumulate: - amp_scaler.step(optimizer) - optimizer.zero_grad() - amp_scaler.update() - lr_scheduler.step() - - else: - with torch.no_grad(): - outputs = model(**batch) - - t1 = time.time() - - t_escape = t1 - t0 - learning_rate = optimizer.param_groups[0]["lr"] - loss = outputs["loss"].item() - accu = outputs["accu"].item() - cnt = outputs["count"] - time_since_start = time.time() - init_t - eta = ( - datetime.timedelta( - seconds=round(time_since_start / (step_idx + 1) * step_max) - ) - if component == "train" - else -1 - ) - - if not step_idx % 10: - logger.info( - f"{component}: step={step_idx}/{step_max}, total_time={time_since_start:.1f}," - + f" eta={eta}," - + f" batch_time={t_escape:.3f}, bs={batch_size}, lr={learning_rate:.6f}," - + f" loss={loss:.4f}, accu={accu:.4f}", - ) - - summary[component]["t"] += t_escape - summary[component]["loss"].append(loss) - summary[component]["accu"].append(accu) - summary[component]["count"].append(cnt) - - if not accumulate: - step_idx += 1 - - return loss, step_idx - - # Start training or evaluating - train_step_idx = 0 - if not args.skip_train: - try: - model.train() - for epoch in range(epochs): - logger.info(f"\nEpoch {epoch}") - - # Make sure that per-rank sampling is really random - for sampler in samplers.values(): - sampler.set_epoch(epoch) - - for i_batch, batch in enumerate(dataloaders["train"]): - grad_accumulate = ( - i_batch % config_training["gradient_accumulation"] != 0 - ) - - _, train_step_idx = step( - batch, - component="train", - step_idx=train_step_idx, - step_max=config_training["num_train_steps"], - accumulate=grad_accumulate, - ) - - if not (train_step_idx + 1) % config_training["eval_frequency"]: - print_summary( - summary["train"], - False, - train_step_idx, - model, - model_path, - logger, - ) - - eval_model(model, dataloaders, "dev", config_training, step) - - print_summary( - summary["dev"], - True, - train_step_idx, - model, - model_path, - logger, - tb_logger, - ) - - if train_step_idx == config_training["num_train_steps"]: - break - - except KeyboardInterrupt as e: - print(e) - - checkpoint = torch.load(model_path, map_location="cpu") - model.load_state_dict(checkpoint["model_state_dict"]) - model.eval() - try: - eval_model(model, dataloaders, "test", config_training, step) - except StopIteration: - pass - - print_summary(summary["test"], False, train_step_idx, model, model_path, logger) - - def get_arg_parser(): parser = argparse.ArgumentParser() parser.add_argument( @@ -563,23 +146,149 @@ def get_arg_parser(): type=dict, default=None, ) - parser.add_argument( - "--tb_dir", - type=str, - help="Path to the tensorboard directory", - dest="tb_dir", - default=f"/checkpoints/{os.getenv('USER')}/xformers/tb", - ) return parser +def setup_log(args, rank, attention_name, task): + logger = TensorBoardLogger( + save_dir=args.checkpoint_dir, + name='', # remove lightning_logs subdirectory + version=f"{task}__{attention_name}__{rank}" + ) + log_dir = os.path.join(logger._save_dir, logger._version) + return log_dir, logger + + +def rewrite_hyper(config, rewrites): + def replace(config_dict, k, v): + if len(k.split(":")) == 1: + config_dict[k] = v + return + first_key = k.split(":")[0] + assert first_key in config_dict, first_key + k = k[len(first_key) + 1 :] + replace(config_dict[first_key], k, v) + + for k, v in rewrites.items(): + replace(config, k, v) + return config + + +def build_dataloaders( + args: argparse.Namespace, + config_training: Dict, + num_workers: int = 4, + ) -> Dict[str, DataLoader]: + datasets = {} + for component in ("train", "dev", "test"): + datasets[component] = LRADataset( + file_path=f"datasets/{args.task}.{component}.pickle", + seq_len=config_training["seq_len"], + ) + + # Gradient accumulation + accumu_steps = config_training["gradient_accumulation"] + logging.info(f"accumu_steps={accumu_steps}") + + # Batch size + per_gpu_batch_size = ( + config_training["batch_size"] // args.world_size // accumu_steps + ) + logging.warning( + f"Requested batch size: {config_training['batch_size']}. Given world size and grad accumulation, per-gpu batch is {per_gpu_batch_size}" + ) + + # Training epochs + if accumu_steps > 1: + config_training["num_train_steps"] *= accumu_steps + config_training["num_eval_steps"] *= accumu_steps + config_training["eval_frequency"] *= accumu_steps + args.epochs = math.ceil( + config_training["num_train_steps"] + * config_training["batch_size"] + / len(datasets["train"]) + ) + logging.warning( + "Requested train steps: {config_training['num_train_steps']}. Given dataset, this translates into {args.epochs} epochs." + ) + + dataloaders = { + k: DataLoader( + v, + batch_size=per_gpu_batch_size, + shuffle=False, + pin_memory=True, + num_workers=num_workers, + ) + for k, v in datasets.items() + } + return dataloaders + + +def benchmark(args): + log_dir, logger = setup_log(args, "main", f"{args.attention}", f"{args.task}") + args.logger = logger + + config = load_config(args.config) + + config_task = config[f"{args.task}"] + if args.sweep_parameters is not None: + logging.info("Replacing hyperparameters") + rewrite_hyper(config_task, args.sweep_parameters) + + config_training = config_task["training"] + config_training["seq_len"] = config_task["model"]["common"]["seq_len"] + logging.info(f"Learning rate: {config_training['learning_rate']}") + + pl.seed_everything(config_training.get("seed", 0)) + dataloaders = build_dataloaders(args, config_training) + + model = build_model(args, config) + + checkpoint_callback = ModelCheckpoint( + monitor="val_accu", + mode="max", + dirpath=args.checkpoint_dir, + filename="{epoch}-{val_accu:.2f}", + every_n_train_steps=config_training["eval_frequency"], + ) + + # args.epochs is saved as side effect of build_dataloaders + steps_per_epoch = config_training["num_train_steps"] / args.epochs + + trainer = pl.Trainer( + accelerator="ddp", + accumulate_grad_batches=config_training["gradient_accumulation"], + callbacks=[checkpoint_callback], + detect_anomaly=args.debug, + deterministic=True, + gpus=args.world_size, + logger=logger, + max_steps=config_training["num_train_steps"], + precision=16 if config_training["mixed_precision"] else 32, + val_check_interval=config_training["eval_frequency"] / steps_per_epoch, + ) + + if not args.skip_train: + trainer.fit( + model, + train_dataloaders=dataloaders["train"], + val_dataloaders=dataloaders["dev"], + ) + + ckpt_path = sorted( + glob(os.path.join(args.checkpoint_dir, f"*.{checkpoint_callback.FILE_EXTENSION}")), + key=os.path.getctime, + reverse=True + )[-1] + trainer.validate( + model, + dataloaders=dataloaders["test"], + ckpt_path=ckpt_path, + ) + + if __name__ == "__main__": parser = get_arg_parser() args = parser.parse_args() - setup_log(args, "main", f"{args.attention}", f"{args.task}") - - with temp_files_ctx(num=1) as temp_files: - args.temp_file = temp_files[0] - torch.multiprocessing.spawn( - benchmark, args=(args,), nprocs=args.world_size, join=True - ) + benchmark(args) diff --git a/xformers/benchmarks/LRA/run_tasks_pl.py b/xformers/benchmarks/LRA/run_tasks_pl.py deleted file mode 100644 index 24bea7e3f4..0000000000 --- a/xformers/benchmarks/LRA/run_tasks_pl.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -# CREDITS: adapted from the Nystromformer repo -# https://github.com/mlpen/Nystromformer - -import argparse -import datetime -import json -import logging -import math -import os -import random -import sys -import time -from contextlib import suppress -from enum import Enum -from pathlib import Path -from typing import Any, Dict - -import numpy as np -import pytorch_lightning as pl -import torch -import torch.distributed as dist -import torch.nn as nn -from fvcore.nn import FlopCountAnalysis, flop_count_str -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger -from torch.utils.data import DataLoader, DistributedSampler -from torch.utils.tensorboard import SummaryWriter - -from xformers.benchmarks.LRA.code.dataset import LRADataset -from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual -from xformers.benchmarks.utils import temp_files_ctx -from xformers.components.attention import ATTENTION_REGISTRY - - -class Task(str, Enum): - Retrieval = "retrieval" - ListOps = "listops" - Image = "image" - PathfinderBaseline = "pathfinder32-curv_baseline" - PathfinderContour9 = "pathfinder32-curv_contour_length_9" - PathfinderContour14 = "pathfinder32-curv_contour_length_14" - Text = "text" - - -def load_config(path: str) -> Dict: - with open(Path(path).absolute(), "r") as fileio: - config = json.load(fileio) - - # Duplicate the pathfinder configs - config["pathfinder32-curv_baseline"] = config["pathfinder32"] - config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"] - config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"] - return config - - -def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: - task = args.task - attention_name = args.attention - - if task == Task.Retrieval: - model: nn.Module = ModelForSCDual(config[f"{task}"], attention_name) - else: - model = ModelForSC(config[f"{task}"], attention_name) - - logging.info(model) - logging.info( - f"num_parameter: {np.sum([np.prod(weight.size()) for weight in model.parameters()]) // 1e3 / 1e3}M" - ) - - with torch.no_grad(): - # Check the flops - seq_len = config[f"{task}"]["model"]["common"]["seq_len"] - x = torch.rand(1, seq_len).long() - mask = torch.rand(1, seq_len).long() - indices = torch.rand(1, seq_len).long() - flops = FlopCountAnalysis(model.model, (x, mask, indices)) - logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops") - logging.info(flop_count_str(flops)) - - return model - - -def get_arg_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--attention", - type=str, - help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. \ - A list can be passed to test several mechanisms in sequence", - dest="attention", - required=True, - ) - parser.add_argument( - "--task", - type=Task, - help=f"Task to chose, among {[t.value for t in Task]}.", - dest="task", - required=True, - ) - parser.add_argument( - "--skip_train", - type=bool, - help="Whether to skip training, and test an existing model", - dest="skip_train", - default=False, - ) - parser.add_argument( - "--config", - type=str, - help="Path to the config being used", - dest="config", - default="./config.json", - ) - parser.add_argument( - "--checkpoint_dir", - type=str, - help="Path to the checkpoint directory", - dest="checkpoint_dir", - default=f"/checkpoints/{os.getenv('USER')}/xformers", - ) - parser.add_argument( - "--debug", - help="Make it easier to debug a possible issue", - dest="debug", - default=False, - action="store_true", - ) - parser.add_argument( - "--world_size", - help="Number of GPUs used", - dest="world_size", - type=int, - default=1, - ) - parser.add_argument( - "--sweep_parameters", - help="Rewrite some hyperparameters in the config", - dest="sweep_parameters", - type=dict, - default=None, - ) - return parser - - -def setup_log(args, rank, attention_name, task): - logger = TensorBoardLogger( - save_dir=args.checkpoint_dir, - version=f"{task}__{attention_name}__{rank}" - ) - log_dir = os.path.join(logger._save_dir, logger._version) - return log_dir, logger - - -def rewrite_hyper(config, rewrites): - def replace(config_dict, k, v): - if len(k.split(":")) == 1: - config_dict[k] = v - return - first_key = k.split(":")[0] - assert first_key in config_dict, first_key - k = k[len(first_key) + 1 :] - replace(config_dict[first_key], k, v) - - for k, v in rewrites.items(): - replace(config, k, v) - return config - - -def build_dataloaders(args: argparse.Namespace, config_training: Dict): - datasets = {} - for component in ("train", "dev", "test"): - datasets[component] = LRADataset( - file_path=f"datasets/{args.task}.{component}.pickle", - seq_len=config_training["seq_len"], - ) - - # Gradient accumulation - accumu_steps = config_training["gradient_accumulation"] - logging.info(f"accumu_steps={accumu_steps}") - - # Batch size - per_gpu_batch_size = ( - config_training["batch_size"] // args.world_size // accumu_steps - ) - logging.warning( - f"Requested batch size: {config_training['batch_size']}. Given world size and grad accumulation, per-gpu batch is {per_gpu_batch_size}" - ) - - # Training epochs - if accumu_steps > 1: - config_training["num_train_steps"] *= accumu_steps - config_training["num_eval_steps"] *= accumu_steps - args.epochs = math.ceil( - config_training["num_train_steps"] - * config_training["batch_size"] - / len(datasets["train"]) - ) - logging.warning( - "Requested train steps: {config_training['num_train_steps']}. Given dataset, this translates into {args.epochs} epochs." - ) - - dataloaders = { - k: DataLoader( - v, - batch_size=per_gpu_batch_size, - shuffle=False, - pin_memory=True, - ) - for k, v in datasets.items() - } - return dataloaders - - -def benchmark(args): - log_dir, logger = setup_log(args, "main", f"{args.attention}", f"{args.task}") - args.logger = logger - - config = load_config(args.config) - - config_task = config[f"{args.task}"] - if args.sweep_parameters is not None: - logging.info("Replacing hyperparameters") - rewrite_hyper(config_task, args.sweep_parameters) - - config_training = config_task["training"] - config_training["seq_len"] = config_task["model"]["common"]["seq_len"] - logging.info(f"Learning rate: {config_training['learning_rate']}") - - pl.seed_everything(config_training.get("seed", 0)) - dataloaders = build_dataloaders(args, config_training) - - model = build_model(args, config) - - checkpoint_callback = ModelCheckpoint( - monitor="val_accu", - dirpath=args.checkpoint_dir, - filename=logger._version, - every_n_train_steps=config_training["eval_frequency"], - ) - ckpt_path = os.path.join(args.checkpoint_dir, f"{logger._version}.ckpt") - - trainer = pl.Trainer( - #accelerator="ddp", - accumulate_grad_batches=config_training["gradient_accumulation"], - callbacks=[checkpoint_callback], - gpus=args.world_size, - max_steps=config_training["num_train_steps"], - precision=16 if config_training["mixed_precision"] else 32, - ) - - if not args.skip_train: - trainer.fit( - model, - train_dataloaders=dataloaders["train"], - val_dataloaders=dataloaders["dev"], - ) - trainer.validate( - model, - dataloaders=dataloaders["test"], - ckpt_path=ckpt_path, - ) - - -if __name__ == "__main__": - parser = get_arg_parser() - args = parser.parse_args() - benchmark(args) From 3aca9bef96a3dad2ad38499d677ca03631bdf2c7 Mon Sep 17 00:00:00 2001 From: lisjin Date: Tue, 28 Jun 2022 10:24:15 -0700 Subject: [PATCH 4/9] Fix coding style, remove unused imports --- xformers/benchmarks/LRA/code/model_wrapper.py | 9 +++-- xformers/benchmarks/LRA/run_tasks.py | 40 ++++++++----------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/xformers/benchmarks/LRA/code/model_wrapper.py b/xformers/benchmarks/LRA/code/model_wrapper.py index 46d0a953b7..47cad94bbb 100755 --- a/xformers/benchmarks/LRA/code/model_wrapper.py +++ b/xformers/benchmarks/LRA/code/model_wrapper.py @@ -119,7 +119,7 @@ def __init__(self, config, model_name): super().__init__() config_model = config["model"] - self.config_training=config["training"] + self.config_training = config["training"] self.enable_amp = config["training"]["mixed_precision"] self.pooling_mode = Pooling(config_model["pooling_mode"]) @@ -158,7 +158,8 @@ def configure_optimizers(self): lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer=optimizer, max_lr=self.config_training["learning_rate"], - pct_start=self.config_training["warmup"] / self.config_training["num_train_steps"], + pct_start=self.config_training["warmup"] + / self.config_training["num_train_steps"], anneal_strategy=self.config_training["lr_decay"], total_steps=self.config_training["num_train_steps"], ) @@ -174,7 +175,9 @@ def eval_epoch_end(self, outputs, prefix="train"): counts = torch.tensor([x["count"] for x in outputs]).float() logs["count"] = counts.sum() for key in ("accu", "loss"): - logs[key] = (torch.tensor([x[key] for x in outputs]) * counts).sum() / logs["count"] + logs[key] = (torch.tensor([x[key] for x in outputs]) * counts).sum() / logs[ + "count" + ] self.log(f"{prefix}_accu_mean", logs["accu"], sync_dist=True) return logs diff --git a/xformers/benchmarks/LRA/run_tasks.py b/xformers/benchmarks/LRA/run_tasks.py index 22e703687b..645b9606ca 100644 --- a/xformers/benchmarks/LRA/run_tasks.py +++ b/xformers/benchmarks/LRA/run_tasks.py @@ -4,38 +4,27 @@ # LICENSE file in the root directory of this source tree. -# CREDITS: adapted from the Nystromformer repo -# https://github.com/mlpen/Nystromformer - import argparse -import datetime import json import logging import math import os -import random -import sys -import time -from contextlib import suppress from enum import Enum from glob import glob from pathlib import Path -from typing import Any, Dict +from typing import Dict, Tuple import numpy as np import pytorch_lightning as pl import torch -import torch.distributed as dist import torch.nn as nn from fvcore.nn import FlopCountAnalysis, flop_count_str from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from torch.utils.data import DataLoader, DistributedSampler -from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader from xformers.benchmarks.LRA.code.dataset import LRADataset from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual -from xformers.benchmarks.utils import temp_files_ctx from xformers.components.attention import ATTENTION_REGISTRY @@ -149,13 +138,14 @@ def get_arg_parser(): return parser -def setup_log(args, rank, attention_name, task): +def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]: + experiment_name = f"{task}__{attention_name}" logger = TensorBoardLogger( save_dir=args.checkpoint_dir, - name='', # remove lightning_logs subdirectory - version=f"{task}__{attention_name}__{rank}" + name="", # remove lightning_logs subdirectory + version=experiment_name, ) - log_dir = os.path.join(logger._save_dir, logger._version) + log_dir = os.path.join(logger._save_dir, experiment_name) return log_dir, logger @@ -175,10 +165,10 @@ def replace(config_dict, k, v): def build_dataloaders( - args: argparse.Namespace, - config_training: Dict, - num_workers: int = 4, - ) -> Dict[str, DataLoader]: + args: argparse.Namespace, + config_training: Dict, + num_workers: int = 4, +) -> Dict[str, DataLoader]: datasets = {} for component in ("train", "dev", "test"): datasets[component] = LRADataset( @@ -226,7 +216,7 @@ def build_dataloaders( def benchmark(args): - log_dir, logger = setup_log(args, "main", f"{args.attention}", f"{args.task}") + log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}") args.logger = logger config = load_config(args.config) @@ -277,9 +267,11 @@ def benchmark(args): ) ckpt_path = sorted( - glob(os.path.join(args.checkpoint_dir, f"*.{checkpoint_callback.FILE_EXTENSION}")), + glob( + os.path.join(args.checkpoint_dir, f"*.{checkpoint_callback.FILE_EXTENSION}") + ), key=os.path.getctime, - reverse=True + reverse=True, )[-1] trainer.validate( model, From 4257066bfcef88411e38376f42e04f314062d200 Mon Sep 17 00:00:00 2001 From: lisjin Date: Tue, 28 Jun 2022 10:59:13 -0700 Subject: [PATCH 5/9] Fix flake8 error --- xformers/benchmarks/LRA/run_tasks.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/xformers/benchmarks/LRA/run_tasks.py b/xformers/benchmarks/LRA/run_tasks.py index 645b9606ca..79bed4053c 100644 --- a/xformers/benchmarks/LRA/run_tasks.py +++ b/xformers/benchmarks/LRA/run_tasks.py @@ -14,7 +14,6 @@ from pathlib import Path from typing import Dict, Tuple -import numpy as np import pytorch_lightning as pl import torch import torch.nn as nn @@ -24,7 +23,8 @@ from torch.utils.data import DataLoader from xformers.benchmarks.LRA.code.dataset import LRADataset -from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual +from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC,\ + ModelForSCDual from xformers.components.attention import ATTENTION_REGISTRY @@ -59,9 +59,8 @@ def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: model = ModelForSC(config[f"{task}"], attention_name) logging.info(model) - logging.info( - f"num_parameter: {np.sum([np.prod(weight.size()) for weight in model.parameters()]) // 1e3 / 1e3}M" - ) + summary = pl.utilities.model_summary.LayerSummary(model) + logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M") with torch.no_grad(): # Check the flops @@ -156,7 +155,7 @@ def replace(config_dict, k, v): return first_key = k.split(":")[0] assert first_key in config_dict, first_key - k = k[len(first_key) + 1 :] + k = k[len(first_key) + 1:] replace(config_dict[first_key], k, v) for k, v in rewrites.items(): @@ -185,7 +184,9 @@ def build_dataloaders( config_training["batch_size"] // args.world_size // accumu_steps ) logging.warning( - f"Requested batch size: {config_training['batch_size']}. Given world size and grad accumulation, per-gpu batch is {per_gpu_batch_size}" + f"Requested batch size: {config_training['batch_size']}. Given world\ + size and grad accumulation, per-gpu batch is\ + {per_gpu_batch_size}" ) # Training epochs @@ -199,7 +200,8 @@ def build_dataloaders( / len(datasets["train"]) ) logging.warning( - "Requested train steps: {config_training['num_train_steps']}. Given dataset, this translates into {args.epochs} epochs." + "Requested train steps: {config_training['num_train_steps']}. Given\ + dataset, this translates into {args.epochs} epochs." ) dataloaders = { @@ -268,7 +270,10 @@ def benchmark(args): ckpt_path = sorted( glob( - os.path.join(args.checkpoint_dir, f"*.{checkpoint_callback.FILE_EXTENSION}") + os.path.join( + args.checkpoint_dir, + f"*.{checkpoint_callback.FILE_EXTENSION}", + ) ), key=os.path.getctime, reverse=True, From bd7638ff45c3ab9414e0f21f2c14f43cc4f1bb7e Mon Sep 17 00:00:00 2001 From: lisjin Date: Tue, 28 Jun 2022 12:06:56 -0700 Subject: [PATCH 6/9] Make isort happy --- xformers/benchmarks/LRA/run_tasks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xformers/benchmarks/LRA/run_tasks.py b/xformers/benchmarks/LRA/run_tasks.py index 79bed4053c..bce23f7940 100644 --- a/xformers/benchmarks/LRA/run_tasks.py +++ b/xformers/benchmarks/LRA/run_tasks.py @@ -23,8 +23,7 @@ from torch.utils.data import DataLoader from xformers.benchmarks.LRA.code.dataset import LRADataset -from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC,\ - ModelForSCDual +from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual from xformers.components.attention import ATTENTION_REGISTRY @@ -155,7 +154,7 @@ def replace(config_dict, k, v): return first_key = k.split(":")[0] assert first_key in config_dict, first_key - k = k[len(first_key) + 1:] + k = k[len(first_key) + 1 :] replace(config_dict[first_key], k, v) for k, v in rewrites.items(): From 1ae0aa808ddda83648f5c1703a07cc134353c6fd Mon Sep 17 00:00:00 2001 From: lisjin Date: Tue, 28 Jun 2022 14:21:44 -0700 Subject: [PATCH 7/9] Let pre-commit handle formatting... --- xformers/benchmarks/LRA/code/model_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xformers/benchmarks/LRA/code/model_wrapper.py b/xformers/benchmarks/LRA/code/model_wrapper.py index 47cad94bbb..bd2f457966 100755 --- a/xformers/benchmarks/LRA/code/model_wrapper.py +++ b/xformers/benchmarks/LRA/code/model_wrapper.py @@ -144,7 +144,7 @@ def training_step(self, batch, batch_idx): def training_epoch_end(self, outputs): logs = self.eval_epoch_end(outputs) - self.log(f"train_accu_mean", logs["accu"], sync_dist=True) + self.log("train_accu_mean", logs["accu"], sync_dist=True) def configure_optimizers(self): optimizer = torch.optim.AdamW( From cb4dfa0ea346ce374f5366483b863a836d705955 Mon Sep 17 00:00:00 2001 From: lisjin Date: Sat, 2 Jul 2022 19:04:52 -0700 Subject: [PATCH 8/9] Add type hints, fix eval behavior --- xformers/benchmarks/LRA/code/model_wrapper.py | 43 ++++++++++++++----- xformers/benchmarks/LRA/run_tasks.py | 38 ++++------------ 2 files changed, 42 insertions(+), 39 deletions(-) diff --git a/xformers/benchmarks/LRA/code/model_wrapper.py b/xformers/benchmarks/LRA/code/model_wrapper.py index bd2f457966..6760f11911 100755 --- a/xformers/benchmarks/LRA/code/model_wrapper.py +++ b/xformers/benchmarks/LRA/code/model_wrapper.py @@ -8,6 +8,7 @@ # https://github.com/mlpen/Nystromformer from enum import Enum +from typing import Dict, Union import pytorch_lightning as pl import torch @@ -18,6 +19,8 @@ from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig from xformers.utils import generate_matching_config +PLOutput = Dict[str, Union[float, torch.Tensor]] + class Pooling(str, Enum): MEAN = "mean" @@ -136,9 +139,16 @@ def __init__(self, config, model_name): * ff_config["hidden_layer_multiplier"] ) - def training_step(self, batch, batch_idx): + def get_progress_bar_dict(self): + return { + k: v for k, v in super().get_progress_bar_dict().items() if k != "v_num" + } + + def training_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: outputs = self(**batch) - self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) + self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) # type: ignore self.log("train_accu", outputs["accu"], sync_dist=True) return outputs @@ -166,11 +176,11 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler] - def eval_step(self, batch, batch_idx): + def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput: outputs = self(**batch) return outputs - def eval_epoch_end(self, outputs, prefix="train"): + def eval_epoch_end(self, outputs, prefix: str = "train"): logs = {} counts = torch.tensor([x["count"] for x in outputs]).float() logs["count"] = counts.sum() @@ -181,16 +191,20 @@ def eval_epoch_end(self, outputs, prefix="train"): self.log(f"{prefix}_accu_mean", logs["accu"], sync_dist=True) return logs - def validation_step(self, batch, batch_idx): + def validation_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: outputs = self.eval_step(batch, batch_idx) - self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) - self.log("val_accu", outputs["accu"], sync_dist=True) + self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) # type: ignore + self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True) return outputs def validation_epoch_end(self, outputs): self.eval_epoch_end(outputs, prefix="val") - def test_step(self, batch, batch_idx): + def test_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: return self.eval_step(batch, batch_idx) def test_epoch_end(self, outputs): @@ -208,7 +222,9 @@ def __init__(self, config, model_name): dim_mlp=self.dim_mlp, ) - def forward(self, input_ids_0, mask_0, label): + def forward( # type: ignore + self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor + ): if self.pooling_mode == Pooling.CLS: input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) @@ -241,7 +257,14 @@ def __init__(self, config, model_name): dim_mlp=self.dim_mlp, ) - def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label): + def forward( # type: ignore + self, + input_ids_0: torch.Tensor, + input_ids_1: torch.Tensor, + mask_0: torch.Tensor, + mask_1: torch.Tensor, + label: torch.Tensor, + ): mask_0, mask_1 = mask_0.long(), mask_1.long() diff --git a/xformers/benchmarks/LRA/run_tasks.py b/xformers/benchmarks/LRA/run_tasks.py index bce23f7940..0a44ba4a84 100644 --- a/xformers/benchmarks/LRA/run_tasks.py +++ b/xformers/benchmarks/LRA/run_tasks.py @@ -7,10 +7,8 @@ import argparse import json import logging -import math import os from enum import Enum -from glob import glob from pathlib import Path from typing import Dict, Tuple @@ -20,6 +18,7 @@ from fvcore.nn import FlopCountAnalysis, flop_count_str from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.strategies import DDPStrategy from torch.utils.data import DataLoader from xformers.benchmarks.LRA.code.dataset import LRADataset @@ -184,8 +183,8 @@ def build_dataloaders( ) logging.warning( f"Requested batch size: {config_training['batch_size']}. Given world\ - size and grad accumulation, per-gpu batch is\ - {per_gpu_batch_size}" + size and grad accumulation, per-gpu batch is\ + {per_gpu_batch_size}" ) # Training epochs @@ -193,15 +192,6 @@ def build_dataloaders( config_training["num_train_steps"] *= accumu_steps config_training["num_eval_steps"] *= accumu_steps config_training["eval_frequency"] *= accumu_steps - args.epochs = math.ceil( - config_training["num_train_steps"] - * config_training["batch_size"] - / len(datasets["train"]) - ) - logging.warning( - "Requested train steps: {config_training['num_train_steps']}. Given\ - dataset, this translates into {args.epochs} epochs." - ) dataloaders = { k: DataLoader( @@ -244,20 +234,20 @@ def benchmark(args): every_n_train_steps=config_training["eval_frequency"], ) - # args.epochs is saved as side effect of build_dataloaders - steps_per_epoch = config_training["num_train_steps"] / args.epochs - trainer = pl.Trainer( - accelerator="ddp", + accelerator="gpu", + strategy=DDPStrategy(find_unused_parameters=args.debug), accumulate_grad_batches=config_training["gradient_accumulation"], callbacks=[checkpoint_callback], detect_anomaly=args.debug, deterministic=True, gpus=args.world_size, + limit_val_batches=config_training["num_eval_steps"], logger=logger, max_steps=config_training["num_train_steps"], precision=16 if config_training["mixed_precision"] else 32, - val_check_interval=config_training["eval_frequency"] / steps_per_epoch, + val_check_interval=config_training["eval_frequency"] + / float(len(dataloaders["train"])), ) if not args.skip_train: @@ -267,20 +257,10 @@ def benchmark(args): val_dataloaders=dataloaders["dev"], ) - ckpt_path = sorted( - glob( - os.path.join( - args.checkpoint_dir, - f"*.{checkpoint_callback.FILE_EXTENSION}", - ) - ), - key=os.path.getctime, - reverse=True, - )[-1] trainer.validate( model, dataloaders=dataloaders["test"], - ckpt_path=ckpt_path, + ckpt_path=checkpoint_callback.best_model_path, ) From 5d06c4734fcd4ac35b217dbd7bf2934c1fb87713 Mon Sep 17 00:00:00 2001 From: lisjin Date: Tue, 5 Jul 2022 11:31:57 -0700 Subject: [PATCH 9/9] Evaluate PL refactor with batch_submit.py --- .../benchmarks/LRA/batch_fetch_results.py | 44 +++++--------- xformers/benchmarks/LRA/batch_submit.py | 9 --- xformers/benchmarks/LRA/code/model_wrapper.py | 11 +--- xformers/benchmarks/LRA/run_tasks.py | 57 ++++++++++++++----- xformers/benchmarks/LRA/run_with_submitit.py | 2 +- xformers/factory/model_factory.py | 4 +- 6 files changed, 62 insertions(+), 65 deletions(-) diff --git a/xformers/benchmarks/LRA/batch_fetch_results.py b/xformers/benchmarks/LRA/batch_fetch_results.py index ccb99d0302..88227ac331 100644 --- a/xformers/benchmarks/LRA/batch_fetch_results.py +++ b/xformers/benchmarks/LRA/batch_fetch_results.py @@ -10,16 +10,6 @@ from pathlib import Path from typing import Any, Dict -reference_steps = { - "image": 35176, - "listops": 10000, - "pathfinder32-curv_contour_length_14": 62400, - "pathfinder32-curv_baseline": 62400, - "pathfinder32-curv_contour_length_9": 62400, - "text": 20000, - "retrieval": 30000, -} - if __name__ == "__main__": # Get the user requests parser = argparse.ArgumentParser( @@ -38,10 +28,10 @@ for attention in filter(lambda x: x.is_dir(), root.iterdir()): logging.info(f"\nFound results for {attention.stem}") - task_logs = attention.glob("*/*.log") + task_jsons = attention.glob("*/test_eval_summary.json") results[attention.stem] = {} - for task in filter(lambda x: "__0" in str(x), task_logs): + for task in task_jsons: task_name = task.stem.split("__")[0] logging.info(f"Logs found for task: {task_name}") results[attention.stem][task_name] = -1 @@ -49,25 +39,17 @@ # - collect the individual results with open(task, "r") as result_file: - for line in reversed(result_file.readlines()): - if '"component": "test"' in line: - found_result = True - - # Check that all the steps are done - res = json.loads(line) - - if res["train_step_idx"] == reference_steps[task_name]: - results[attention.stem][task_name] = res["best_accu"] - logging.info( - f"Final result found for {task_name}: {results[attention.stem][task_name]}" - ) - else: - logging.info( - "Current step: {}/{}. Not finished".format( - res["train_step_idx"], reference_steps[task_name] - ) - ) - break + dct = json.load(result_file) + if "test_accu_mean" in dct: + found_result = True + results[attention.stem][task_name] = dct["test_accu_mean"] + + logging.info( + f"Final result found for {task_name} at epoch {dct['train_step_idx']}: " + f"{results[attention.stem][task_name]}" + ) + else: + break # - report an error if no result was found if not found_result: diff --git a/xformers/benchmarks/LRA/batch_submit.py b/xformers/benchmarks/LRA/batch_submit.py index 964a516ce2..a3077aa62a 100644 --- a/xformers/benchmarks/LRA/batch_submit.py +++ b/xformers/benchmarks/LRA/batch_submit.py @@ -37,14 +37,6 @@ def get_default_shared_folder() -> str: parser.add_argument( "--partition", default="a100", type=str, help="Partition where to submit" ) - parser.add_argument( - "-tb", - "--tb_path", - type=str, - help="Path to the tensorboard directory", - dest="tb_dir", - default=f"/{default_checkpoint_path}/{os.getenv('USER')}/xformers/tb", - ) args = parser.parse_args() for attention in args.attentions: @@ -54,5 +46,4 @@ def get_default_shared_folder() -> str: + f" --attention {attention} --task {task} --config {args.config_path}" + f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}" + f" --partition {args.partition}" - + f" --tb_dir {args.tb_dir}/{attention}/{task}" ) diff --git a/xformers/benchmarks/LRA/code/model_wrapper.py b/xformers/benchmarks/LRA/code/model_wrapper.py index 6760f11911..5eb3e2ca74 100755 --- a/xformers/benchmarks/LRA/code/model_wrapper.py +++ b/xformers/benchmarks/LRA/code/model_wrapper.py @@ -139,11 +139,6 @@ def __init__(self, config, model_name): * ff_config["hidden_layer_multiplier"] ) - def get_progress_bar_dict(self): - return { - k: v for k, v in super().get_progress_bar_dict().items() if k != "v_num" - } - def training_step( # type: ignore self, batch: Dict[str, torch.Tensor], batch_idx: int ) -> PLOutput: @@ -184,11 +179,11 @@ def eval_epoch_end(self, outputs, prefix: str = "train"): logs = {} counts = torch.tensor([x["count"] for x in outputs]).float() logs["count"] = counts.sum() - for key in ("accu", "loss"): - logs[key] = (torch.tensor([x[key] for x in outputs]) * counts).sum() / logs[ + for k in ("accu", "loss"): + logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[ "count" ] - self.log(f"{prefix}_accu_mean", logs["accu"], sync_dist=True) + self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True) return logs def validation_step( # type: ignore diff --git a/xformers/benchmarks/LRA/run_tasks.py b/xformers/benchmarks/LRA/run_tasks.py index 0a44ba4a84..8f6e26edb7 100644 --- a/xformers/benchmarks/LRA/run_tasks.py +++ b/xformers/benchmarks/LRA/run_tasks.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn from fvcore.nn import FlopCountAnalysis, flop_count_str -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.strategies import DDPStrategy from torch.utils.data import DataLoader @@ -51,10 +51,11 @@ def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: task = args.task attention_name = args.attention - if task == Task.Retrieval: - model: nn.Module = ModelForSCDual(config[f"{task}"], attention_name) - else: - model = ModelForSC(config[f"{task}"], attention_name) + model: pl.LightningModule = ( + ModelForSCDual(config[f"{task}"], attention_name) + if task == Task.Retrieval + else ModelForSC(config[f"{task}"], attention_name) + ) logging.info(model) summary = pl.utilities.model_summary.LayerSummary(model) @@ -111,6 +112,11 @@ def get_arg_parser(): dest="checkpoint_dir", default=f"/checkpoints/{os.getenv('USER')}/xformers", ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Path to checkpoint", + ) parser.add_argument( "--debug", help="Make it easier to debug a possible issue", @@ -187,12 +193,6 @@ def build_dataloaders( {per_gpu_batch_size}" ) - # Training epochs - if accumu_steps > 1: - config_training["num_train_steps"] *= accumu_steps - config_training["num_eval_steps"] *= accumu_steps - config_training["eval_frequency"] *= accumu_steps - dataloaders = { k: DataLoader( v, @@ -206,6 +206,20 @@ def build_dataloaders( return dataloaders +def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]: + eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step} + for k, v in trainer.callback_metrics.items(): + eval_summary[k] = v.item() + return eval_summary + + +class BasicProgressBar(TQDMProgressBar): + def get_metrics(self, trainer, model): + items = super().get_metrics(trainer, model) + items.pop("v_num", None) + return items + + def benchmark(args): log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}") args.logger = logger @@ -226,6 +240,7 @@ def benchmark(args): model = build_model(args, config) + progress_bar = BasicProgressBar() checkpoint_callback = ModelCheckpoint( monitor="val_accu", mode="max", @@ -236,15 +251,18 @@ def benchmark(args): trainer = pl.Trainer( accelerator="gpu", - strategy=DDPStrategy(find_unused_parameters=args.debug), + strategy=DDPStrategy(find_unused_parameters=args.debug) + if not args.skip_train + else None, accumulate_grad_batches=config_training["gradient_accumulation"], - callbacks=[checkpoint_callback], + callbacks=[progress_bar, checkpoint_callback], detect_anomaly=args.debug, deterministic=True, gpus=args.world_size, limit_val_batches=config_training["num_eval_steps"], logger=logger, max_steps=config_training["num_train_steps"], + num_sanity_val_steps=int(not args.skip_train), precision=16 if config_training["mixed_precision"] else 32, val_check_interval=config_training["eval_frequency"] / float(len(dataloaders["train"])), @@ -256,15 +274,24 @@ def benchmark(args): train_dataloaders=dataloaders["train"], val_dataloaders=dataloaders["dev"], ) + ckpt_path = checkpoint_callback.best_model_path + else: + ckpt_path = args.checkpoint_path - trainer.validate( + trainer.test( model, dataloaders=dataloaders["test"], - ckpt_path=checkpoint_callback.best_model_path, + ckpt_path=ckpt_path, ) + eval_summary = get_eval_summary(trainer) + with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f: + logging.info(f"Saving test results at {f.name}") + json.dump(eval_summary, f) if __name__ == "__main__": parser = get_arg_parser() args = parser.parse_args() + if args.skip_train and args.checkpoint_path is None: + raise parser.error("Must provide --checkpoint_path if --skip_train=True") benchmark(args) diff --git a/xformers/benchmarks/LRA/run_with_submitit.py b/xformers/benchmarks/LRA/run_with_submitit.py index d7af681e86..13945aac6c 100644 --- a/xformers/benchmarks/LRA/run_with_submitit.py +++ b/xformers/benchmarks/LRA/run_with_submitit.py @@ -76,7 +76,7 @@ def __init__(self, args): def __call__(self): self._setup_gpu_args() - benchmark(self.args.rank, self.args) + benchmark(self.args) def checkpoint(self): self.args.dist_url = get_init_file().as_uri() diff --git a/xformers/factory/model_factory.py b/xformers/factory/model_factory.py index 0b0d642a46..8b9e6bdcdc 100644 --- a/xformers/factory/model_factory.py +++ b/xformers/factory/model_factory.py @@ -277,7 +277,9 @@ def forward( # Apply the optional input masking if encoder_input_mask is not None: - x += encoder_input_mask.unsqueeze(0).unsqueeze(-1) + if x.dim() - encoder_input_mask.dim() > 1: + encoder_input_mask.unsqueeze(0) + x += encoder_input_mask.unsqueeze(-1) x = encoders(x) memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)