From 26fbdd9bffab444f2c8978cf6610a57895e22e12 Mon Sep 17 00:00:00 2001 From: igormolybogFB <106693515+igormolybogFB@users.noreply.github.com> Date: Wed, 1 Jun 2022 10:47:22 -0700 Subject: [PATCH 1/8] added four blocksparsity layouts Four sparsity layouts from DeepSpeed are now available for blocksparse attention on xFormer: Fixed BSLongformer BigBird Variable sparsity_configs.py (https://fburl.com/code/s2n7x8gs) contains flexible objects with many parameters. The default parameters can be invoked through the quick_ functionality in attention_patterns.py (https://fburl.com/code/hya0t9e7) the produced layouts can be turned into a pattern with layout_to_pattern function (https://fburl.com/code/1qmsntyj) Detailed notes for the task: https://docs.google.com/document/d/1cBlZeccvphI-d5avLkgKwZ4ScXQM1q6igqvlyFDXdDc/edit?usp=sharing --- .DS_Store | Bin 0 -> 8196 bytes tests/test_attention_patterns.py | 80 +++ xformers/.DS_Store | Bin 0 -> 6148 bytes xformers/components/.DS_Store | Bin 0 -> 6148 bytes .../attention/attention_patterns.py | 29 + .../components/attention/sparsity_config.py | 634 ++++++++++++++++++ 6 files changed, 743 insertions(+) create mode 100644 .DS_Store create mode 100644 xformers/.DS_Store create mode 100644 xformers/components/.DS_Store create mode 100644 xformers/components/attention/sparsity_config.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e4c136bf463067aefda298c5ee52d00322782785 GIT binary patch literal 8196 zcmeHM&2AGh5FWQF-6;G-sFf-Z(n?%Q3A7*~E=l?W5wwxIEj<7Vo1eB(x4R*`X{f5A zoZ%gK1+F{^@4^Yb@orG;Y>B@I=pPlS;HrFUT3wHVuXW!+>GHFkl!k4EzfW;Lhe`O}O{tUQ-$d3EGfmR;0#(rk!3nvEPa zD4-^_q&xxU6f^eJ8H>*v>P9ls4!%-Uu+w!bcS_z}|Lf zU0sFbYE(NXzvgTKu>pi5{n931Lor9As4Ujuf}+ z%wJ$9!rVg4CUr*4zhWsr(Na1do>1^vjMc(uJ$PgZ9@$Uxgo3aXM2Ai_&e#thNBl4& zSJMOM`VMDyRn6hrHBHIi1kHLs>?Ehpw|d8Ghf$&MT}(}9X0FUyvsTV}>l_AsCvu}+ zSa!Rw`DrHzy!x>2)?c+-{aSu*HSi<1?YFuT&~CP%@@B8?H-mmT==sf1Vx73d%39f4 ze&OtFrLeMSKPoJrFWP5M9u^kumF36h=UHq1?)|6RwIlD?4?bb{BZ=f^3C>bA@%b|Y zAH;qSPn@pr1wNk_jL4zqQ10O-&Om-|x@gf@m-}LrAq%^CALme4?)$FJJ)c?%_hb5( zgx9fkY3!<=Fi$!I({aNB$N%e-zyIqn^K2L}3|t}xL}txdD`P>!OH5xI$J#dXJ7i9* zH;AQ#ppfA>qzuO)fB#{Kx(!vv6cag!C0daF`iB7gU@-UprTATC?tl3P{aNt~MSV|L literal 0 HcmV?d00001 diff --git a/tests/test_attention_patterns.py b/tests/test_attention_patterns.py index 28f2d7b76c..05faf3a846 100644 --- a/tests/test_attention_patterns.py +++ b/tests/test_attention_patterns.py @@ -210,3 +210,83 @@ def test_alibi_pattern(): mask = AP.alibi_pattern(1e-3, (16, 128, 128)) # Minor, check that all the top left corners are True assert torch.sum(mask[:, 0, 0]) == 16 + + +def test_quick_layouts(): + + seq_size = 128 + block_size = 16 + num_heads = 2 + + # Fixed + assert torch.allclose(AP.quick_fixed_layout(num_heads, block_size, seq_size), torch.Tensor( + [[[1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1]], + + [[1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1]]]).long()) + + # BSLongformer + assert torch.allclose(AP.quick_bslongformer_layout(num_heads, block_size, seq_size), torch.Tensor( + [[[1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 0], + [1, 0, 
0, 0, 1, 1, 1, 0], + [1, 0, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1]], + + [[1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 0], + [1, 0, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1]]]).long()) + + # Variable + assert torch.allclose(AP.quick_variable_layout(num_heads, block_size, seq_size), torch.Tensor( + [[[1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1]], + + [[1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1]]]).long()) + + + +def test_layout_to_pattern(): + torch.allclose( + AP.layout_to_pattern(layout=torch.Tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]]]), block_size=2), + torch.Tensor( + [ + [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]], + [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]], + ] + ), +) diff --git a/xformers/.DS_Store b/xformers/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..08c08046dd2c137e61164c18b199defc47431d33 GIT binary patch literal 6148 zcmeHKL66cv82yGtD4+=kP4;5a#4B0RMHAx%+4bO6j2_e=h3Yoj?KHGVHiU$;{s;ep zSAU8B#go48o7tokNIV&%^O2c1oq2ECzM0N2M5G4u={=%05qU_A)h4nX!R?&3qGD^V z1C{gQr$N4-PjAyU;>V?*Ea8(`$FQafVYTrE3Nj{E}a3+Fi8iMlXRg_HiVxZ^A zG!yO6Gbb)AFZNsmFO_l5C@hUL&y6pIxlzZ{hT!MBdM2I~MRpCmnan-Rip`#7 zHWaxY&B2cFhC82H{3!BqD;Iklurt$Z?P+(_&H?AZ|LXv+4?YrO-{RV!EFH+?2mq|2 zS^{PMv%nhPV&CH0AX*U0NP$MG>=8p5IqH4O>swqKG;&h*@S*ITl|7*-xjXvzg-)t( z(B;kn=fI`|2j;TP`~TDZ=l@NSyK)XV2mUDsRP}UtI>0O0y>;v5c(3)5-XU=?Z*5Q{ k$n1731>TAektEROa{<`5xHgCyg!>~PHn_|=@JAi^1;ZxNYybcN literal 0 HcmV?d00001 diff --git a/xformers/components/.DS_Store b/xformers/components/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..83ebe1c02127edc9322d3639a80252871d66837b GIT binary patch literal 6148 zcmeHK&2G~`5T0!Vb%=^06rdL*OI)iGXh1?-LMR8W1i=B2v18L39B3d-u-5OvP7f?v-mDin}|Fl){QI3YJ~eaostdP zvJI`oM@l6H`0df;l(lPc3^)ep3)p8)*2qBviH}q-`22AL6npLVUep}Q@8$(h!)d{9pCq8lBLtK+x?+7HXD~N zZ+Tl@%X>RK(u-jkmb2*~EM9Z%h1Oa0ryNEv(_|5}J5O|8hH0J@0!ZToCU0J)d8`)$ zJ%V27r|YkJL~y>oor>-Kj2``!JMUH|yuz3#5x+kbF!()6z1zVqmL@G%?b`ZE)S z!x5*W2(w+4FR0XTF&`Frrt?GeGO|VQ!Bj#qw5EamD)s?9 zjVY3Ov9FvC>wL(%`&wCd9r)1rGOty$=0c3F#@i}C#A~qXr-k_^;wQ%Z*hkF_WF1yf z+k%<{SqL@)E~!OwJgUf>Vn4>{GmLwLbjG-Ay>V2k-cUWa*j4GZ&h5!La11yGE*b-T zKKMwi`&KWFmZJliJOQ8^Xcj@4|12=Zx4Lij(rENRlz~DUsIpHCW#DM{tzF;hrO^gX z%051n{j#zz6eVAWe_zZ=^^JD9W56-6%D|?nw)y^ldiMK&736A;0ms0FVn8(xh6e*o z$= List[float]: # Now threshold arbitrarily, report the mask return alibi < threshold + + +def quick_fixed_layout(num_heads: int, block_size: int, seq_len:int): + config = FixedSparsityConfig(num_heads=num_heads, block=block_size) + return config.make_layout(seq_len) + + +def quick_variable_layout(num_heads: int, block_size: int, seq_len:int): + config = VariableSparsityConfig(num_heads=num_heads, block=block_size) + return config.make_layout(seq_len) + + +def quick_bigbird_layout(num_heads: int, block_size: int, seq_len:int): + config = BigBirdSparsityConfig(num_heads=num_heads, block=block_size) + return config.make_layout(seq_len) + + +def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len:int): + 
config = BSLongformerSparsityConfig(num_heads=num_heads, block=block_size) + return config.make_layout(seq_len) + + +def layout_to_pattern(layout: torch.Tensor, block_size: int): + r""" + create a pattern of shape [heads, seq, seq] out of a blocksparse + layout of shape [heads, seq/block_size, seq/block_size] + """ + return torch.kron(layout, torch.ones(block_size, block_size)) diff --git a/xformers/components/attention/sparsity_config.py b/xformers/components/attention/sparsity_config.py new file mode 100644 index 0000000000..47dca78755 --- /dev/null +++ b/xformers/components/attention/sparsity_config.py @@ -0,0 +1,634 @@ +""" +The code has been adopted from DeepSpeed +(https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py) +""" + +import random +import torch + +class SparsityConfig: + """Abstract Configuration class to store `sparsity configuration of a self attention layer`. + It contains shared property of different block-sparse sparsity patterns. However, each class needs to extend it based on required property and functionality. + """ + def __init__(self, num_heads, block=16, different_layout_per_head=False): + """Initialize the Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. + """ + + self.num_heads = num_heads + self.block = block + self.different_layout_per_head = different_layout_per_head + self.num_layout_heads = num_heads if different_layout_per_head else 1 + + def setup_layout(self, seq_len): + """Create layout tensor for the given sequence length + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout of all head; initialized with zero + """ + + if (seq_len % self.block != 0): + raise ValueError( + f'Sequence Length, {seq_len}, needs to be dividable by Block size {self.block}!' + ) + num_blocks = seq_len // self.block + # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout. + layout = torch.zeros((self.num_heads, num_blocks, num_blocks), dtype=torch.int64) + return layout + + def check_and_propagate_first_head_layout(self, layout): + """If all heads require same sparsity layout, it propagate first head layout to all heads + Arguments: + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head + """ + + if not self.different_layout_per_head: + layout[1:self.num_heads, :, :] = layout[0, :, :] + return layout + + +class DenseSparsityConfig(SparsityConfig): + """Configuration class to store `Dense` configuration. + In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension. 
+ """ + def __init__(self, num_heads, block=16, different_layout_per_head=False): + """Initialize the Dense Sparsity Pattern Config. + In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension. + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + seq_len: required: an integer determining number of attention heads of the layer. + different_layout_per_head: optional: this is just for the sake of consistency with other sparsity formats; can ignore it for DenseSparsityConfig + """ + + super().__init__(num_heads, block, different_layout_per_head) + + def make_layout(self, seq_len): + """Set 1 to all blocks of the layout meanins the pattern is dense; not sparse. + Arguments: + seq_len: required: an integer determining the underling sequence length; must be <= max sequence length + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; for dense everything is 1 + """ + + layout = self.setup_layout(seq_len) + layout[:, :, :] = 1 + return layout + + +class FixedSparsityConfig(SparsityConfig): + """Configuration class to store `Fixed` sparsity configuration. + For more details about this sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. + This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity. + """ + def __init__(self, + num_heads, + block=16, + different_layout_per_head=False, + num_local_blocks=4, + num_global_blocks=1, + attention='bidirectional', + horizontal_global_attention=False, + num_different_global_patterns=1): + """Initialize `Fixed` Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. + num_local_blocks: optional: an integer determining the number of blocks in local attention window. + num_global_blocks: optional: an integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure. + horizontal_global_attention: optional: a boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `bidirectional`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks. 
+ num_different_global_patterns: optional: an integer determining number of different global attentions layouts. While global attention can be fixed by which block/s are representative of any local window, since there are multi-heads, each head can use a different global representative. For example, with 4 blocks local window and global attention size of 1 block, we can have 4 different versions in which the first, Second, third, or forth block of each local window can be global representative of that window. This parameter determines how many of such patterns we want. Of course, there is a limitation based on num_local_blocks and num_global_blocks. + """ + + super().__init__(num_heads, block, different_layout_per_head) + + self.num_local_blocks = num_local_blocks + + if (num_local_blocks % num_global_blocks != 0): + raise ValueError( + f'Number of blocks in a local window, {num_local_blocks}, must be dividable by number of global blocks, {num_global_blocks}!' + ) + self.num_global_blocks = num_global_blocks + + if (attention != 'unidirectional' and attention != 'bidirectional'): + raise NotImplementedError( + 'only \"uni/bi-directional\" attentions are supported for now!') + self.attention = attention + + if (attention != 'bidirectional' and horizontal_global_attention): + raise ValueError( + 'only \"bi-directional\" attentions can support horizontal global attention!' + ) + self.horizontal_global_attention = horizontal_global_attention + + if (num_different_global_patterns > 1 and not different_layout_per_head): + raise ValueError( + f'Number of different layouts cannot be more than one when you have set a single layout for all heads! Set different_layout_per_head to True.' + ) + if (num_different_global_patterns > (num_local_blocks // num_global_blocks)): + raise ValueError( + f'Number of layout versions (num_different_global_patterns), {num_different_global_patterns}, cannot be larger than number of local window blocks divided by number of global blocks, {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!' + ) + self.num_different_global_patterns = num_different_global_patterns + + def set_local_layout(self, h, layout): + """Sets local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local layout is set + """ + + num_blocks = layout.shape[1] + for i in range(0, num_blocks, self.num_local_blocks): + end = min(i + self.num_local_blocks, num_blocks) + for row in range(i, end): + for col in range( + i, + (row + 1 if self.attention == 'unidirectional' else end)): + layout[h, row, col] = 1 + return layout + + def set_global_layout(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. + Currently we set global blocks starting from the last block of a local window to the first one. That means if a local window consists of 4 blocks and global attention size is one block, we use block #4 in each local window as global. If we have different layout per head, then other heads will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global attentions, multiple head may have same global attentions. 
+ Note) if horizontal_global_attention is set, global blocks will be set both horizontally and vertically. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + first_global_block_idx = self.num_local_blocks - ( + 1 + h % self.num_different_global_patterns) * self.num_global_blocks + + # set all global blocks except the last one if (in last local window) + end = num_blocks - (num_blocks % self.num_local_blocks) + for i in range(first_global_block_idx, end, self.num_local_blocks): + + # vertical global attention + first_row = 0 if self.attention == 'bidirectional' else i + #(((i // self.num_local_blocks) + 1) * self.num_local_blocks) + #if (first_row < num_blocks): + layout[h, first_row:, i:i + self.num_global_blocks] = 1 + + # horizontal global attention; only in bidirectional attention + if (self.horizontal_global_attention): + layout[h, i:i + self.num_global_blocks, :] = 1 + + # set last global blocks; handle possible short last local window + if (end < num_blocks): + start = min(end + first_global_block_idx, + num_blocks - self.num_global_blocks) + end = start + self.num_global_blocks + + # vertical global attention + first_row = 0 if self.attention == 'bidirectional' else start + #(((start // self.num_local_blocks) + 1) * self.num_local_blocks) + #if (first_row < num_blocks): + layout[h, first_row:, start:end] = 1 + + # horizontal global attention + if (self.horizontal_global_attention): + layout[h, start:end, :] = 1 + return layout + + def make_layout(self, seq_len): + """Generates `Fixed` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed` sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_local_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class VariableSparsityConfig(SparsityConfig): + """Configuration class to store `Variable` sparsity configuration. + This layout is an extension of FixedSparsityConfig in which: + - user can set random layout; default value is zero means no random block + - user can provide a list of local block sizes + - user can provide a list of global block indices. + For more details about `Fixed` sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. + This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity. + """ + def __init__(self, + num_heads, + block=16, + different_layout_per_head=False, + num_random_blocks=0, + local_window_blocks=[4], + global_block_indices=[0], + global_block_end_indices=None, + attention='bidirectional', + horizontal_global_attention=False): + """Initialize `Variable` Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. 
+ block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. Currently this sparsity config can only assign single layout to all heads; needs to be extended for different layout per head. + num_random_blocks: optional: an integer determining the number of random blocks in each block row. + local_window_blocks: optional: a list of integers determining the number of blocks in each local attention window. It assumes first number determines # of blocks in the first local window, second the second window, ..., and the last number determines the number of blocks in the remaining local windows. + global_block_indices: optional: a list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Default value is only index 0. Notice that if global_block_end_indices parameter is set, this parameter is used as starting index of each global window. + global_block_end_indices: optional: a list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size of global_block_indices parameter, and combining this two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered as global attention. + num_global_blocks: optional: an integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure. + horizontal_global_attention: optional: a boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `bidirectional`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks. + """ + + super().__init__(num_heads, block, different_layout_per_head) + + self.num_random_blocks = num_random_blocks + self.local_window_blocks = local_window_blocks + self.global_block_indices = global_block_indices + + if (global_block_end_indices is not None): + if (len(global_block_indices) != len(global_block_end_indices)): + raise ValueError( + f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!' + ) + for _, (start_idx, end_idx) in enumerate(zip(global_block_indices, global_block_end_indices)): + if start_idx >= end_idx: + raise ValueError( + f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!' 
+ ) + self.global_block_end_indices = global_block_end_indices + + if (attention != 'unidirectional' and attention != 'bidirectional'): + raise NotImplementedError( + 'only \"uni/bi-directional\" attentions are supported for now!') + self.attention = attention + + if (attention != 'bidirectional' and horizontal_global_attention): + raise ValueError( + 'only \"bi-directional\" attentions can support horizontal global attention!' + ) + self.horizontal_global_attention = horizontal_global_attention + + def set_random_layout(self, h, layout): + """Sets random attention layout used by the given head in the sparse attention. + Note) By default, it assumes there will be a unique random block layout for all heads; unless `different_layout_per_head` parameter is set in which each head can have a different random layout. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which random layout is set + """ + + num_blocks = layout.shape[1] + if (num_blocks < self.num_random_blocks): + raise ValueError( + f'Number of random blocks, {self.num_random_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + ) + for row in range(0, num_blocks): + rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks) + layout[h, row, rnd_cols] = 1 + return layout + + def set_local_layout(self, h, layout): + """Sets local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local layout is set + """ + + num_blocks = layout.shape[1] + start_block_idx = 0 + end_block_idx = 0 + for block_size in self.local_window_blocks: + end_block_idx += block_size + end_block_idx = min(end_block_idx, num_blocks) + for row in range(start_block_idx, end_block_idx): + for col in range( + start_block_idx, + (row + 1 if self.attention == 'unidirectional' else end_block_idx)): + layout[h, row, col] = 1 + start_block_idx += block_size + + # if there is any remaining not attended part, use the lats local window block size as local window for the remaining applicable local windows + for i in range(start_block_idx, num_blocks, block_size): + end_block_idx = min(i + block_size, num_blocks) + for row in range(i, end_block_idx): + for col in range( + i, + (row + 1 if self.attention == 'unidirectional' else end_block_idx)): + layout[h, row, col] = 1 + return layout + + def set_global_layout(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if (self.global_block_end_indices is None): + for idx in self.global_block_indices: + # if global block idx is in the range of the sequence blocks + if (idx < num_blocks): + #global rows + if (self.horizontal_global_attention): + layout[h, idx, :] = 1 + + #global columns + first_row = 0 if self.attention == 'bidirectional' else idx + layout[h, first_row:, idx] = 1 + else: + for _, (start_idx, end_idx) in enumerate(zip(self.global_block_indices, self.global_block_end_indices)): + # if global block idx is in the range of the sequence blocks + if (start_idx < num_blocks): + end_idx = min(end_idx, num_blocks) + #global rows + if (self.horizontal_global_attention): + layout[h, start_idx:end_idx, :] = 1 + + #global columns + first_row = 0 if self.attention == 'bidirectional' else start_idx + layout[h, first_row:, start_idx:end_idx] = 1 + return layout + + def make_layout(self, seq_len): + """Generates `Variable` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable` sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_random_layout(h, layout) + layout = self.set_local_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class BigBirdSparsityConfig(SparsityConfig): + """Configuration class to store `BigBird` sparsity configuration. + For more details about this sparsity config, please see `Big Bird: Transformers for Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf + This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity. + """ + def __init__(self, + num_heads, + block=16, + different_layout_per_head=False, + num_random_blocks=1, + num_sliding_window_blocks=3, + num_global_blocks=1, + attention='bidirectional'): + """Initialize the BigBird Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. + num_random_blocks: optional: an integer determining the number of random blocks in each block row. + num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding local attention window. + num_global_blocks: optional: an integer determining how many consecutive blocks, starting from index 0, are considered as global attention. 
Global block tokens will be attended by all other block tokens and will attend to all other block tokens as well. + attention: optional: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure. + """ + + super().__init__(num_heads, block, different_layout_per_head) + + self.num_random_blocks = num_random_blocks + self.num_sliding_window_blocks = num_sliding_window_blocks + self.num_global_blocks = num_global_blocks + + if (attention != 'unidirectional' and attention != 'bidirectional'): + raise NotImplementedError( + 'only \"uni/bi-directional\" attentions are supported for now!') + self.attention = attention + + def set_random_layout(self, h, layout): + """Sets random attention layout used by the given head in the sparse attention. + Note) By default, it assumes there will be a unique random block layout for all heads; unless `different_layout_per_head` parameter is set in which each head can have a different random layout. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which random layout is set + """ + + num_blocks = layout.shape[1] + if (num_blocks < self.num_random_blocks): + raise ValueError( + f'Number of random blocks, {self.num_random_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + ) + + for row in range(0, num_blocks): + sample_range = range( + 0, + num_blocks) if self.attention == 'bidirectional' else range(0, + row + 1) + rnd_cols = random.sample(sample_range, self.num_random_blocks) + layout[h, row, rnd_cols] = 1 + return layout + + def set_sliding_window_layout(self, h, layout): + """Sets sliding local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local sliding window layout is set + """ + + num_blocks = layout.shape[1] + if (num_blocks < self.num_sliding_window_blocks): + raise ValueError( + f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + ) + + w = self.num_sliding_window_blocks // 2 + for row in range(0, num_blocks): + start = max(0, row - w) + end = min(row + w + 1, num_blocks) + layout[h, row, start:end] = 1 + return layout + + def set_global_layout_itc(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if (num_blocks < self.num_global_blocks): + raise ValueError( + f'Number of global blocks, {self.num_global_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + ) + + #global rows + layout[h, 0:self.num_global_blocks, :] = 1 + + #global columns + layout[h, :, 0:self.num_global_blocks] = 1 + + if self.attention == 'unidirectional': + # zero out anything attending to the future + layout = torch.tril(layout) + + return layout + + def make_layout(self, seq_len): + """Generates `BigBird` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird` sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_random_layout(h, layout) + layout = self.set_sliding_window_layout(h, layout) + layout = self.set_global_layout_itc(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class BSLongformerSparsityConfig(SparsityConfig): + """Configuration class to store edited `Longformer` sparsity configuration. + Note) this is a block-sparse version of the Longformer which is slightly different than original Longformer; which is element-wise sparsity. + For more details about this sparsity config, please see `Longformer: The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf + This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity. + """ + def __init__(self, + num_heads, + block=16, + different_layout_per_head=False, + num_sliding_window_blocks=3, + global_block_indices=[0], + global_block_end_indices=None, + attention='bidirectional'): + """Initialize the edited `Longformer` Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. + num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding local attention window. + global_block_indices: optional: a list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Default value is only index 0. Notice that if global_block_end_indices parameter is set, this parameter is used as starting index of each global window. + global_block_end_indices: optional: a list of integers determining end indices of global window blocks. By default this is not used. 
But if it is set, it must have the same size of global_block_indices parameter, and combining this two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered as global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure. + """ + + super().__init__(num_heads, block, different_layout_per_head) + + self.num_sliding_window_blocks = num_sliding_window_blocks + self.global_block_indices = global_block_indices + self.attention = attention + + if (global_block_end_indices is not None): + if (len(global_block_indices) != len(global_block_end_indices)): + raise ValueError( + f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!' + ) + for _, (start_idx, end_idx) in enumerate(zip(global_block_indices, global_block_end_indices)): + if start_idx >= end_idx: + raise ValueError( + f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!' + ) + self.global_block_end_indices = global_block_end_indices + + def set_sliding_window_layout(self, h, layout): + """Sets sliding local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local sliding window layout is set + """ + + num_blocks = layout.shape[1] + if (num_blocks < self.num_sliding_window_blocks): + raise ValueError( + f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + ) + + w = self.num_sliding_window_blocks // 2 + for row in range(0, num_blocks): + start = max(0, row - w) + end = min(row + w + 1, num_blocks) + layout[h, row, start:end] = 1 + return layout + + def set_global_layout(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if (self.global_block_end_indices is None): + for idx in self.global_block_indices: + # if global block idx is in the range of the sequence blocks + if (idx < num_blocks): + #global rows + layout[h, idx, :] = 1 + + #global columns + layout[h, :, idx] = 1 + else: + for _, (start_idx, end_idx) in enumerate(zip(self.global_block_indices, self.global_block_end_indices)): + # if global block idx is in the range of the sequence blocks + if (start_idx < num_blocks): + end_idx = min(end_idx, num_blocks) + #global rows + layout[h, start_idx:end_idx, :] = 1 + + #global columns + layout[h, :, start_idx:end_idx] = 1 + if self.attention == 'unidirectional': + layout = torch.tril(layout) + return layout + + def make_layout(self, seq_len): + """Generates edited `Longformer` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer` sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_sliding_window_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout From 9a5f696df1505de830cb0fd7218b1a3f085cbb8c Mon Sep 17 00:00:00 2001 From: igormolybogFB <106693515+igormolybogFB@users.noreply.github.com> Date: Wed, 1 Jun 2022 15:55:25 -0700 Subject: [PATCH 2/8] pre-commit linting fixed linting errors fixed for std-blocksparsity-attention-layouts --- tests/test_attention_patterns.py | 164 +++-- .../benchmarks/LRA/code/config_nystrom.json | 4 +- xformers/components/attention/.DS_Store | Bin 0 -> 8196 bytes .../attention/attention_patterns.py | 20 +- .../components/attention/sparsity_config.py | 571 ++++++++++++------ 5 files changed, 489 insertions(+), 270 deletions(-) create mode 100644 xformers/components/attention/.DS_Store diff --git a/tests/test_attention_patterns.py b/tests/test_attention_patterns.py index 05faf3a846..68d498c023 100644 --- a/tests/test_attention_patterns.py +++ b/tests/test_attention_patterns.py @@ -219,74 +219,112 @@ def test_quick_layouts(): num_heads = 2 # Fixed - assert torch.allclose(AP.quick_fixed_layout(num_heads, block_size, seq_size), torch.Tensor( - [[[1, 1, 1, 1, 0, 0, 0, 1], - [1, 1, 1, 1, 0, 0, 0, 1], - [1, 1, 1, 1, 0, 0, 0, 1], - [1, 1, 1, 1, 0, 0, 0, 1], - [0, 0, 0, 1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 1, 1]], - - [[1, 1, 1, 1, 0, 0, 0, 1], - [1, 1, 1, 1, 0, 0, 0, 1], - [1, 1, 1, 1, 0, 0, 0, 1], - [1, 1, 1, 1, 0, 0, 0, 1], - [0, 0, 0, 1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 1, 1]]]).long()) + assert torch.allclose( + AP.quick_fixed_layout(num_heads, block_size, seq_size), + torch.Tensor( + [ + [ + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 
1, 1], + ], + [ + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 0, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1, 1], + ], + ] + ).long(), + ) # BSLongformer - assert torch.allclose(AP.quick_bslongformer_layout(num_heads, block_size, seq_size), torch.Tensor( - [[[1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 0, 1, 1, 1, 0, 0, 0], - [1, 0, 0, 1, 1, 1, 0, 0], - [1, 0, 0, 0, 1, 1, 1, 0], - [1, 0, 0, 0, 0, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 1, 1]], - - [[1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 0, 1, 1, 1, 0, 0, 0], - [1, 0, 0, 1, 1, 1, 0, 0], - [1, 0, 0, 0, 1, 1, 1, 0], - [1, 0, 0, 0, 0, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 1, 1]]]).long()) + assert torch.allclose( + AP.quick_bslongformer_layout(num_heads, block_size, seq_size), + torch.Tensor( + [ + [ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 0], + [1, 0, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1], + ], + [ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 0], + [1, 0, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1], + ], + ] + ).long(), + ) # Variable - assert torch.allclose(AP.quick_variable_layout(num_heads, block_size, seq_size), torch.Tensor( - [[[1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 0, 0, 0, 1, 1, 1, 1], - [1, 0, 0, 0, 1, 1, 1, 1], - [1, 0, 0, 0, 1, 1, 1, 1], - [1, 0, 0, 0, 1, 1, 1, 1]], - - [[1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 0, 0, 0, 1, 1, 1, 1], - [1, 0, 0, 0, 1, 1, 1, 1], - [1, 0, 0, 0, 1, 1, 1, 1], - [1, 0, 0, 0, 1, 1, 1, 1]]]).long()) - + assert torch.allclose( + AP.quick_variable_layout(num_heads, block_size, seq_size), + torch.Tensor( + [ + [ + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + ], + [ + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + ], + ] + ).long(), + ) def test_layout_to_pattern(): torch.allclose( - AP.layout_to_pattern(layout=torch.Tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]]]), block_size=2), - torch.Tensor( - [ - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]], - [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]], - ] - ), -) + AP.layout_to_pattern( + layout=torch.Tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]]]), block_size=2 + ), + torch.Tensor( + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + ], + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + ], + ] + ), + ) diff --git a/xformers/benchmarks/LRA/code/config_nystrom.json b/xformers/benchmarks/LRA/code/config_nystrom.json index d5b454264c..4bcc808411 100644 --- a/xformers/benchmarks/LRA/code/config_nystrom.json +++ 
b/xformers/benchmarks/LRA/code/config_nystrom.json @@ -50,7 +50,7 @@ "hidden_layer_multiplier": 2 } } - + ], "extra_settings": { "attention": { @@ -199,7 +199,7 @@ "hidden_layer_multiplier": 2 } } - + ], "extra_settings": { "attention": { diff --git a/xformers/components/attention/.DS_Store b/xformers/components/attention/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..42fec3046492142695ed15a682d18f74110e1e99 GIT binary patch literal 8196 zcmeI1&u-H|5XNUqq?)!0rBX#NNPdD!yg{hSfin+~<{!aos3Rvq@g)r0IR_d+i|1yh1wgB89AF z$ls+b_q5YQVjwY)7)T5x1`-4R0|R`swPhXOeSbHNiGjqxlVm{O4>?`sAjqAU_Ud4; z5&+p_+8Wo?0fxzg90a-Z(xBM%)q@LA7fKB0;oQ$v9XSYc=cR{>^KfxtWfv+GXRCwD zOO^M{779^kmH z-9vjtPA1KK8k2RsOg<9D7ksK>UO45eu7865%y`~_n1#G8Wz||onXBXa;7WE)MdUL^ z&37SRV{007b?m9kjdjh`0^}M?w^-@Ou`=iSHRfJ`8nND3rpTfF+L{+J-r21U^a=X< zl&Jp|R@D0lPMK}sGH|+6I`XT`HuR5JV@uA~*w=j$TLV@bn9O`y%O_++Uzu#^AKC6J zLqzi!#@PE-WA8csjPPyR<&`L7b;LPQ__$}FhKAG=!lgwvWfx3{GUBgsShsYBrN2*^ zo-%H+qEC@#qhMFWY$!){8=1YU3S#bjBQ*5cxTz%u9xVgUd}J3||G(=#|9`acr@#^e zi2**6z3UIxS1=K3>pf1i*3RiSbZx4)^U@#;R*oaB97o>&!;tG7RmBW~+7+o`$E*z#nE%bgKXW literal 0 HcmV?d00001 diff --git a/xformers/components/attention/attention_patterns.py b/xformers/components/attention/attention_patterns.py index 07baf3e099..24dc38482a 100644 --- a/xformers/components/attention/attention_patterns.py +++ b/xformers/components/attention/attention_patterns.py @@ -10,7 +10,13 @@ import numpy as np import torch -from xformers.components.attention.sparsity_config import FixedSparsityConfig, VariableSparsityConfig, BigBirdSparsityConfig, BSLongformerSparsityConfig +from xformers.components.attention.sparsity_config import ( + BigBirdSparsityConfig, + BSLongformerSparsityConfig, + FixedSparsityConfig, + VariableSparsityConfig, +) + # generic nd cases def _generate_nd_grid(*sizes): @@ -261,29 +267,29 @@ def get_slopes_power_of_2(n: int) -> List[float]: return alibi < threshold -def quick_fixed_layout(num_heads: int, block_size: int, seq_len:int): +def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int): config = FixedSparsityConfig(num_heads=num_heads, block=block_size) return config.make_layout(seq_len) -def quick_variable_layout(num_heads: int, block_size: int, seq_len:int): +def quick_variable_layout(num_heads: int, block_size: int, seq_len: int): config = VariableSparsityConfig(num_heads=num_heads, block=block_size) return config.make_layout(seq_len) -def quick_bigbird_layout(num_heads: int, block_size: int, seq_len:int): +def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int): config = BigBirdSparsityConfig(num_heads=num_heads, block=block_size) return config.make_layout(seq_len) -def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len:int): +def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int): config = BSLongformerSparsityConfig(num_heads=num_heads, block=block_size) return config.make_layout(seq_len) def layout_to_pattern(layout: torch.Tensor, block_size: int): r""" - create a pattern of shape [heads, seq, seq] out of a blocksparse - layout of shape [heads, seq/block_size, seq/block_size] + create a pattern of shape [heads, seq, seq] out of a blocksparse + layout of shape [heads, seq/block_size, seq/block_size] """ return torch.kron(layout, torch.ones(block_size, block_size)) diff --git a/xformers/components/attention/sparsity_config.py b/xformers/components/attention/sparsity_config.py index 47dca78755..f0081fd8cd 100644 --- a/xformers/components/attention/sparsity_config.py +++ 
b/xformers/components/attention/sparsity_config.py @@ -4,19 +4,27 @@ """ import random + import torch + class SparsityConfig: """Abstract Configuration class to store `sparsity configuration of a self attention layer`. - It contains shared property of different block-sparse sparsity patterns. However, each class needs to extend it based on required property and functionality. + It contains shared property of different block-sparse sparsity patterns. However, each class + needs to extend it based on required property and functionality. """ + def __init__(self, num_heads, block=16, different_layout_per_head=False): """Initialize the Sparsity Pattern Config. For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. - different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. + block: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be + assigned a different sparsity layout; default is false and this will be satisfied + based on availability. """ self.num_heads = num_heads @@ -29,42 +37,51 @@ def setup_layout(self, seq_len): Arguments: seq_len: required: an integer determining number of attention heads of the layer. Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout of all head; initialized with zero + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout + of all head; initialized with zero """ - if (seq_len % self.block != 0): + if seq_len % self.block != 0: raise ValueError( - f'Sequence Length, {seq_len}, needs to be dividable by Block size {self.block}!' + f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block}!" ) num_blocks = seq_len // self.block # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout. 
- layout = torch.zeros((self.num_heads, num_blocks, num_blocks), dtype=torch.int64) + layout = torch.zeros( + (self.num_heads, num_blocks, num_blocks), dtype=torch.int64 + ) return layout def check_and_propagate_first_head_layout(self, layout): """If all heads require same sparsity layout, it propagate first head layout to all heads Arguments: - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head """ if not self.different_layout_per_head: - layout[1:self.num_heads, :, :] = layout[0, :, :] + layout[1 : self.num_heads, :, :] = layout[0, :, :] return layout class DenseSparsityConfig(SparsityConfig): """Configuration class to store `Dense` configuration. - In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension. + In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and + comprehension. """ + def __init__(self, num_heads, block=16, different_layout_per_head=False): """Initialize the Dense Sparsity Pattern Config. - In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension. + In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison + and comprehension. Arguments: num_heads: required: an integer determining number of attention heads of the layer. seq_len: required: an integer determining number of attention heads of the layer. - different_layout_per_head: optional: this is just for the sake of consistency with other sparsity formats; can ignore it for DenseSparsityConfig + different_layout_per_head: optional: this is just for the sake of consistency with + other sparsity formats; can ignore it for DenseSparsityConfig """ super().__init__(num_heads, block, different_layout_per_head) @@ -72,9 +89,11 @@ def __init__(self, num_heads, block=16, different_layout_per_head=False): def make_layout(self, seq_len): """Set 1 to all blocks of the layout meanins the pattern is dense; not sparse. Arguments: - seq_len: required: an integer determining the underling sequence length; must be <= max sequence length + seq_len: required: an integer determining the underling sequence length; + must be <= max sequence length Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; for dense everything is 1 + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head; for dense everything is 1 """ layout = self.setup_layout(seq_len) @@ -84,59 +103,88 @@ def make_layout(self, seq_len): class FixedSparsityConfig(SparsityConfig): """Configuration class to store `Fixed` sparsity configuration. - For more details about this sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. + For more details about this sparsity config, please see `Generative Modeling with + Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. 
This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity. """ - def __init__(self, - num_heads, - block=16, - different_layout_per_head=False, - num_local_blocks=4, - num_global_blocks=1, - attention='bidirectional', - horizontal_global_attention=False, - num_different_global_patterns=1): + + def __init__( + self, + num_heads, + block=16, + different_layout_per_head=False, + num_local_blocks=4, + num_global_blocks=1, + attention="bidirectional", + horizontal_global_attention=False, + num_different_global_patterns=1, + ): """Initialize `Fixed` Sparsity Pattern Config. For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. - different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. - num_local_blocks: optional: an integer determining the number of blocks in local attention window. - num_global_blocks: optional: an integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention. - attention: optional: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure. - horizontal_global_attention: optional: a boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `bidirectional`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks. - num_different_global_patterns: optional: an integer determining number of different global attentions layouts. While global attention can be fixed by which block/s are representative of any local window, since there are multi-heads, each head can use a different global representative. For example, with 4 blocks local window and global attention size of 1 block, we can have 4 different versions in which the first, Second, third, or forth block of each local window can be global representative of that window. This parameter determines how many of such patterns we want. Of course, there is a limitation based on num_local_blocks and num_global_blocks. + block: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be + assigned a different sparsity layout; default is false and this will be satisfied + based on availability. + num_local_blocks: optional: an integer determining the number of blocks in local attention + window. 
+ num_global_blocks: optional: an integer determining how many consecutive blocks in a local + window is used as the representative of the window for global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. + horizontal_global_attention: optional: a boolean determining if blocks that are global + representative of a local window, also attend to all other blocks. This is valid only if + attention type is `bidirectional`. Looking at the attention matrix, that means global + attention not only includes the vertical blocks, but also horizontal blocks. + num_different_global_patterns: optional: an integer determining number of different global + attentions layouts. While global attention can be fixed by which block/s are representative + of any local window, since there are multi-heads, each head can use a different global representative. + For example, with 4 blocks local window and global attention size of 1 block, we can have 4 different + versions in which the first, Second, third, or forth block of each local window can be global + representative of that window. This parameter determines how many of such patterns we want. + Of course, there is a limitation based on num_local_blocks and num_global_blocks. """ super().__init__(num_heads, block, different_layout_per_head) self.num_local_blocks = num_local_blocks - if (num_local_blocks % num_global_blocks != 0): + if num_local_blocks % num_global_blocks != 0: raise ValueError( - f'Number of blocks in a local window, {num_local_blocks}, must be dividable by number of global blocks, {num_global_blocks}!' + f"""Number of blocks in a local window, {num_local_blocks}, + must be dividable by number of global blocks, {num_global_blocks}!""" ) self.num_global_blocks = num_global_blocks - if (attention != 'unidirectional' and attention != 'bidirectional'): + if attention != "unidirectional" and attention != "bidirectional": raise NotImplementedError( - 'only \"uni/bi-directional\" attentions are supported for now!') + 'only "uni/bi-directional" attentions are supported for now!' + ) self.attention = attention - if (attention != 'bidirectional' and horizontal_global_attention): + if attention != "bidirectional" and horizontal_global_attention: raise ValueError( - 'only \"bi-directional\" attentions can support horizontal global attention!' + 'only "bi-directional" attentions can support horizontal global attention!' ) self.horizontal_global_attention = horizontal_global_attention - if (num_different_global_patterns > 1 and not different_layout_per_head): + if num_different_global_patterns > 1 and not different_layout_per_head: raise ValueError( - f'Number of different layouts cannot be more than one when you have set a single layout for all heads! Set different_layout_per_head to True.' + """Number of different layouts cannot be more than one when you have set a single layout + for all heads! 
Set different_layout_per_head to True.""" ) - if (num_different_global_patterns > (num_local_blocks // num_global_blocks)): + if num_different_global_patterns > (num_local_blocks // num_global_blocks): raise ValueError( - f'Number of layout versions (num_different_global_patterns), {num_different_global_patterns}, cannot be larger than number of local window blocks divided by number of global blocks, {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!' + f"""Number of layout versions (num_different_global_patterns), {num_different_global_patterns}, + cannot be larger than number of local window blocks divided by number of global blocks, + {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!""" ) self.num_different_global_patterns = num_different_global_patterns @@ -144,9 +192,11 @@ def set_local_layout(self, h, layout): """Sets local attention layout used by the given head in the sparse attention. Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which local layout is set """ num_blocks = layout.shape[1] @@ -154,54 +204,64 @@ def set_local_layout(self, h, layout): end = min(i + self.num_local_blocks, num_blocks) for row in range(i, end): for col in range( - i, - (row + 1 if self.attention == 'unidirectional' else end)): + i, (row + 1 if self.attention == "unidirectional" else end) + ): layout[h, row, col] = 1 return layout def set_global_layout(self, h, layout): """Sets global attention layout used by the given head in the sparse attention. - Currently we set global blocks starting from the last block of a local window to the first one. That means if a local window consists of 4 blocks and global attention size is one block, we use block #4 in each local window as global. If we have different layout per head, then other heads will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global attentions, multiple head may have same global attentions. - Note) if horizontal_global_attention is set, global blocks will be set both horizontally and vertically. + Currently we set global blocks starting from the last block of a local window to the first one. + That means if a local window consists of 4 blocks and global attention size is one block, we use + block #4 in each local window as global. If we have different layout per head, then other heads + will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global + attentions, multiple head may have same global attentions. + Note) if horizontal_global_attention is set, global blocks will be set both horizontally and + vertically. 
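To make the head-dependent choice described above concrete, the snippet below reproduces the index arithmetic used further down in this method as a standalone sketch (illustrative values, not part of the patch): with num_local_blocks=4, num_global_blocks=1 and num_different_global_patterns=4, head 0 uses the last block of each local window as the global representative, head 1 the one before it, and so on. Note that num_different_global_patterns > 1 requires different_layout_per_head=True in the actual config.

    num_local_blocks, num_global_blocks, num_different_global_patterns = 4, 1, 4

    for h in range(4):
        first_global_block_idx = (
            num_local_blocks
            - (1 + h % num_different_global_patterns) * num_global_blocks
        )
        print(h, first_global_block_idx)   # head 0 -> 3, head 1 -> 2, head 2 -> 1, head 3 -> 0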
Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which global layout is set """ num_blocks = layout.shape[1] - first_global_block_idx = self.num_local_blocks - ( - 1 + h % self.num_different_global_patterns) * self.num_global_blocks + first_global_block_idx = ( + self.num_local_blocks + - (1 + h % self.num_different_global_patterns) * self.num_global_blocks + ) # set all global blocks except the last one if (in last local window) end = num_blocks - (num_blocks % self.num_local_blocks) for i in range(first_global_block_idx, end, self.num_local_blocks): # vertical global attention - first_row = 0 if self.attention == 'bidirectional' else i - #(((i // self.num_local_blocks) + 1) * self.num_local_blocks) - #if (first_row < num_blocks): - layout[h, first_row:, i:i + self.num_global_blocks] = 1 + first_row = 0 if self.attention == "bidirectional" else i + # (((i // self.num_local_blocks) + 1) * self.num_local_blocks) + # if (first_row < num_blocks): + layout[h, first_row:, i : i + self.num_global_blocks] = 1 # horizontal global attention; only in bidirectional attention - if (self.horizontal_global_attention): - layout[h, i:i + self.num_global_blocks, :] = 1 + if self.horizontal_global_attention: + layout[h, i : i + self.num_global_blocks, :] = 1 # set last global blocks; handle possible short last local window - if (end < num_blocks): - start = min(end + first_global_block_idx, - num_blocks - self.num_global_blocks) + if end < num_blocks: + start = min( + end + first_global_block_idx, num_blocks - self.num_global_blocks + ) end = start + self.num_global_blocks # vertical global attention - first_row = 0 if self.attention == 'bidirectional' else start - #(((start // self.num_local_blocks) + 1) * self.num_local_blocks) - #if (first_row < num_blocks): + first_row = 0 if self.attention == "bidirectional" else start + # (((start // self.num_local_blocks) + 1) * self.num_local_blocks) + # if (first_row < num_blocks): layout[h, first_row:, start:end] = 1 # horizontal global attention - if (self.horizontal_global_attention): + if self.horizontal_global_attention: layout[h, start:end, :] = 1 return layout @@ -210,7 +270,8 @@ def make_layout(self, seq_len): Arguments: seq_len: required: an integer determining number of attention heads of the layer. Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed` sparsity layout of all head + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed` + sparsity layout of all head """ layout = self.setup_layout(seq_len) @@ -228,32 +289,61 @@ class VariableSparsityConfig(SparsityConfig): - user can set random layout; default value is zero means no random block - user can provide a list of local block sizes - user can provide a list of global block indices. - For more details about `Fixed` sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. 
+ For more details about `Fixed` sparsity config, please see `Generative Modeling with + Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity. """ - def __init__(self, - num_heads, - block=16, - different_layout_per_head=False, - num_random_blocks=0, - local_window_blocks=[4], - global_block_indices=[0], - global_block_end_indices=None, - attention='bidirectional', - horizontal_global_attention=False): + + def __init__( + self, + num_heads, + block=16, + different_layout_per_head=False, + num_random_blocks=0, + local_window_blocks=[4], + global_block_indices=[0], + global_block_end_indices=None, + attention="bidirectional", + horizontal_global_attention=False, + ): """Initialize `Variable` Sparsity Pattern Config. For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. - different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. Currently this sparsity config can only assign single layout to all heads; needs to be extended for different layout per head. + block: optional: an integer determining the block size. Current implementation of sparse + self-attention is based on blocked sparse matrices. In which this parameter defines + size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a + different sparsity layout; default is false and this will be satisfied based on + availability. Currently this sparsity config can only assign single layout to all heads; + needs to be extended for different layout per head. num_random_blocks: optional: an integer determining the number of random blocks in each block row. - local_window_blocks: optional: a list of integers determining the number of blocks in each local attention window. It assumes first number determines # of blocks in the first local window, second the second window, ..., and the last number determines the number of blocks in the remaining local windows. - global_block_indices: optional: a list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Default value is only index 0. Notice that if global_block_end_indices parameter is set, this parameter is used as starting index of each global window. - global_block_end_indices: optional: a list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size of global_block_indices parameter, and combining this two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered as global attention. - num_global_blocks: optional: an integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention. - attention: optional: a string determining attention type. 
Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure. - horizontal_global_attention: optional: a boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `bidirectional`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks. + local_window_blocks: optional: a list of integers determining the number of blocks in each + local attention window. It assumes first number determines # of blocks in the first local + window, second the second window, ..., and the last number determines the number of blocks + in the remaining local windows. + global_block_indices: optional: a list of integers determining which blocks are considered + as global attention. Given indices, determine the blocks that all other token blocks + attend to and they attend to all other token blocks. Default value is only index 0. + Notice that if global_block_end_indices parameter is set, this parameter is used as + starting index of each global window. + global_block_end_indices: optional: a list of integers determining end indices of global + window blocks. By default this is not used. But if it is set, it must have the same size + of global_block_indices parameter, and combining this two parameters, for each index i, + blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are + considered as global attention. + num_global_blocks: optional: an integer determining how many consecutive blocks in a local + window is used as the representative of the window for global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. + horizontal_global_attention: optional: a boolean determining if blocks that are global + representative of a local window, also attend to all other blocks. This is valid only if + attention type is `bidirectional`. Looking at the attention matrix, that means global + attention not only includes the vertical blocks, but also horizontal blocks. """ super().__init__(num_heads, block, different_layout_per_head) @@ -262,43 +352,53 @@ def __init__(self, self.local_window_blocks = local_window_blocks self.global_block_indices = global_block_indices - if (global_block_end_indices is not None): - if (len(global_block_indices) != len(global_block_end_indices)): + if global_block_end_indices is not None: + if len(global_block_indices) != len(global_block_end_indices): raise ValueError( - f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!' 
+ f"""Global block start indices length, {len(global_block_indices)}, must be same as + global block end indices length, {len(global_block_end_indices)}!""" ) - for _, (start_idx, end_idx) in enumerate(zip(global_block_indices, global_block_end_indices)): + for _, (start_idx, end_idx) in enumerate( + zip(global_block_indices, global_block_end_indices) + ): if start_idx >= end_idx: raise ValueError( - f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!' + f"""Global block start index, {start_idx}, must be smaller than global block end + index, {end_idx}!""" ) self.global_block_end_indices = global_block_end_indices - if (attention != 'unidirectional' and attention != 'bidirectional'): + if attention != "unidirectional" and attention != "bidirectional": raise NotImplementedError( - 'only \"uni/bi-directional\" attentions are supported for now!') + 'only "uni/bi-directional" attentions are supported for now!' + ) self.attention = attention - if (attention != 'bidirectional' and horizontal_global_attention): + if attention != "bidirectional" and horizontal_global_attention: raise ValueError( - 'only \"bi-directional\" attentions can support horizontal global attention!' + 'only "bi-directional" attentions can support horizontal global attention!' ) self.horizontal_global_attention = horizontal_global_attention def set_random_layout(self, h, layout): """Sets random attention layout used by the given head in the sparse attention. - Note) By default, it assumes there will be a unique random block layout for all heads; unless `different_layout_per_head` parameter is set in which each head can have a different random layout. + Note) By default, it assumes there will be a unique random block layout for all heads; unless + `different_layout_per_head` parameter is set in which each head can have a different random + layout. Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which random layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which random layout is set """ num_blocks = layout.shape[1] - if (num_blocks < self.num_random_blocks): + if num_blocks < self.num_random_blocks: raise ValueError( - f'Number of random blocks, {self.num_random_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" ) for row in range(0, num_blocks): rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks) @@ -309,9 +409,11 @@ def set_local_layout(self, h, layout): """Sets local attention layout used by the given head in the sparse attention. 
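The loops below walk the block rows window by window: the explicitly listed sizes in local_window_blocks are consumed first, then the last size is reused for whatever remains. A standalone sketch of that partitioning with hypothetical values local_window_blocks=[1, 2] and 8 block rows (not part of the patch):

    local_window_blocks = [1, 2]   # hypothetical window sizes, in blocks
    num_blocks = 8                 # e.g. seq_len=128 with 16-token blocks

    windows, start = [], 0
    for size in local_window_blocks:          # explicitly listed windows first ...
        windows.append((start, min(start + size, num_blocks)))
        start += size
    while start < num_blocks:                 # ... then the last size is reused
        windows.append((start, min(start + local_window_blocks[-1], num_blocks)))
        start += local_window_blocks[-1]

    print(windows)   # [(0, 1), (1, 3), (3, 5), (5, 7), (7, 8)]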
Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which local layout is set """ num_blocks = layout.shape[1] @@ -322,18 +424,21 @@ def set_local_layout(self, h, layout): end_block_idx = min(end_block_idx, num_blocks) for row in range(start_block_idx, end_block_idx): for col in range( - start_block_idx, - (row + 1 if self.attention == 'unidirectional' else end_block_idx)): + start_block_idx, + (row + 1 if self.attention == "unidirectional" else end_block_idx), + ): layout[h, row, col] = 1 start_block_idx += block_size - # if there is any remaining not attended part, use the lats local window block size as local window for the remaining applicable local windows + # if there is any remaining not attended part, use the lats local window block size as local + # window for the remaining applicable local windows for i in range(start_block_idx, num_blocks, block_size): end_block_idx = min(i + block_size, num_blocks) for row in range(i, end_block_idx): for col in range( - i, - (row + 1 if self.attention == 'unidirectional' else end_block_idx)): + i, + (row + 1 if self.attention == "unidirectional" else end_block_idx), + ): layout[h, row, col] = 1 return layout @@ -341,34 +446,38 @@ def set_global_layout(self, h, layout): """Sets global attention layout used by the given head in the sparse attention. 
Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which global layout is set """ num_blocks = layout.shape[1] - if (self.global_block_end_indices is None): + if self.global_block_end_indices is None: for idx in self.global_block_indices: # if global block idx is in the range of the sequence blocks - if (idx < num_blocks): - #global rows - if (self.horizontal_global_attention): + if idx < num_blocks: + # global rows + if self.horizontal_global_attention: layout[h, idx, :] = 1 - #global columns - first_row = 0 if self.attention == 'bidirectional' else idx + # global columns + first_row = 0 if self.attention == "bidirectional" else idx layout[h, first_row:, idx] = 1 else: - for _, (start_idx, end_idx) in enumerate(zip(self.global_block_indices, self.global_block_end_indices)): + for _, (start_idx, end_idx) in enumerate( + zip(self.global_block_indices, self.global_block_end_indices) + ): # if global block idx is in the range of the sequence blocks - if (start_idx < num_blocks): + if start_idx < num_blocks: end_idx = min(end_idx, num_blocks) - #global rows - if (self.horizontal_global_attention): + # global rows + if self.horizontal_global_attention: layout[h, start_idx:end_idx, :] = 1 - #global columns - first_row = 0 if self.attention == 'bidirectional' else start_idx + # global columns + first_row = 0 if self.attention == "bidirectional" else start_idx layout[h, first_row:, start_idx:end_idx] = 1 return layout @@ -377,7 +486,8 @@ def make_layout(self, seq_len): Arguments: seq_len: required: an integer determining number of attention heads of the layer. Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable` sparsity layout of all head + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable` + sparsity layout of all head """ layout = self.setup_layout(seq_len) @@ -392,27 +502,44 @@ def make_layout(self, seq_len): class BigBirdSparsityConfig(SparsityConfig): """Configuration class to store `BigBird` sparsity configuration. - For more details about this sparsity config, please see `Big Bird: Transformers for Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf + For more details about this sparsity config, please see `Big Bird: Transformers for + Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity. """ - def __init__(self, - num_heads, - block=16, - different_layout_per_head=False, - num_random_blocks=1, - num_sliding_window_blocks=3, - num_global_blocks=1, - attention='bidirectional'): + + def __init__( + self, + num_heads, + block=16, + different_layout_per_head=False, + num_random_blocks=1, + num_sliding_window_blocks=3, + num_global_blocks=1, + attention="bidirectional", + ): """Initialize the BigBird Sparsity Pattern Config. 
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. - different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. - num_random_blocks: optional: an integer determining the number of random blocks in each block row. - num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding local attention window. - num_global_blocks: optional: an integer determining how many consecutive blocks, starting from index 0, are considered as global attention. Global block tokens will be attended by all other block tokens and will attend to all other block tokens as well. - attention: optional: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure. + block: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned + a different sparsity layout; default is false and this will be satisfied based on + availability. + num_random_blocks: optional: an integer determining the number of random blocks in each + block row. + num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding + local attention window. + num_global_blocks: optional: an integer determining how many consecutive blocks, starting + from index 0, are considered as global attention. Global block tokens will be attended + by all other block tokens and will attend to all other block tokens as well. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. """ super().__init__(num_heads, block, different_layout_per_head) @@ -421,32 +548,38 @@ def __init__(self, self.num_sliding_window_blocks = num_sliding_window_blocks self.num_global_blocks = num_global_blocks - if (attention != 'unidirectional' and attention != 'bidirectional'): + if attention != "unidirectional" and attention != "bidirectional": raise NotImplementedError( - 'only \"uni/bi-directional\" attentions are supported for now!') + 'only "uni/bi-directional" attentions are supported for now!' 
+ ) self.attention = attention def set_random_layout(self, h, layout): """Sets random attention layout used by the given head in the sparse attention. - Note) By default, it assumes there will be a unique random block layout for all heads; unless `different_layout_per_head` parameter is set in which each head can have a different random layout. + Note) By default, it assumes there will be a unique random block layout for all heads; unless + `different_layout_per_head` parameter is set in which each head can have a different random layout. Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which random layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which random layout is set """ num_blocks = layout.shape[1] - if (num_blocks < self.num_random_blocks): + if num_blocks < self.num_random_blocks: raise ValueError( - f'Number of random blocks, {self.num_random_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" ) for row in range(0, num_blocks): - sample_range = range( - 0, - num_blocks) if self.attention == 'bidirectional' else range(0, - row + 1) + sample_range = ( + range(0, num_blocks) + if self.attention == "bidirectional" + else range(0, row + 1) + ) rnd_cols = random.sample(sample_range, self.num_random_blocks) layout[h, row, rnd_cols] = 1 return layout @@ -455,15 +588,18 @@ def set_sliding_window_layout(self, h, layout): """Sets sliding local attention layout used by the given head in the sparse attention. Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local sliding window layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which local sliding window layout is set """ num_blocks = layout.shape[1] - if (num_blocks < self.num_sliding_window_blocks): + if num_blocks < self.num_sliding_window_blocks: raise ValueError( - f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than + overall number of blocks in a row, {num_blocks}!""" ) w = self.num_sliding_window_blocks // 2 @@ -477,24 +613,27 @@ def set_global_layout_itc(self, h, layout): """Sets global attention layout used by the given head in the sparse attention. 
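The `itc` variant implemented below simply marks the first num_global_blocks block rows and columns as dense, and for `unidirectional` attention additionally lower-triangularizes the result. A standalone sketch of the resulting mask for one head, 4 block rows and a single global block (illustrative values, not part of the patch):

    import torch

    num_global_blocks, num_blocks = 1, 4
    layout = torch.zeros((1, num_blocks, num_blocks), dtype=torch.int64)
    layout[0, 0:num_global_blocks, :] = 1    # global rows
    layout[0, :, 0:num_global_blocks] = 1    # global columns
    # for "unidirectional" attention the method additionally applies torch.tril(layout)
    print(layout[0])
    # tensor([[1, 1, 1, 1],
    #         [1, 0, 0, 0],
    #         [1, 0, 0, 0],
    #         [1, 0, 0, 0]])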
Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout + of all head in which global layout is set """ num_blocks = layout.shape[1] - if (num_blocks < self.num_global_blocks): + if num_blocks < self.num_global_blocks: raise ValueError( - f'Number of global blocks, {self.num_global_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + f"""Number of global blocks, {self.num_global_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" ) - #global rows - layout[h, 0:self.num_global_blocks, :] = 1 + # global rows + layout[h, 0 : self.num_global_blocks, :] = 1 - #global columns - layout[h, :, 0:self.num_global_blocks] = 1 + # global columns + layout[h, :, 0 : self.num_global_blocks] = 1 - if self.attention == 'unidirectional': + if self.attention == "unidirectional": # zero out anything attending to the future layout = torch.tril(layout) @@ -505,7 +644,8 @@ def make_layout(self, seq_len): Arguments: seq_len: required: an integer determining number of attention heads of the layer. Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird` sparsity layout of all head + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird` + sparsity layout of all head """ layout = self.setup_layout(seq_len) @@ -520,28 +660,51 @@ def make_layout(self, seq_len): class BSLongformerSparsityConfig(SparsityConfig): """Configuration class to store edited `Longformer` sparsity configuration. - Note) this is a block-sparse version of the Longformer which is slightly different than original Longformer; which is element-wise sparsity. - For more details about this sparsity config, please see `Longformer: The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf + Note) this is a block-sparse version of the Longformer which is slightly different than original + Longformer; which is element-wise sparsity. + For more details about this sparsity config, please see `Longformer: + The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity. """ - def __init__(self, - num_heads, - block=16, - different_layout_per_head=False, - num_sliding_window_blocks=3, - global_block_indices=[0], - global_block_end_indices=None, - attention='bidirectional'): + + def __init__( + self, + num_heads, + block=16, + different_layout_per_head=False, + num_sliding_window_blocks=3, + global_block_indices=[0], + global_block_end_indices=None, + attention="bidirectional", + ): """Initialize the edited `Longformer` Sparsity Pattern Config. For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. 
In which this parameter defines size of such blocks, `Block X Block`. - different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. - num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding local attention window. - global_block_indices: optional: a list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Default value is only index 0. Notice that if global_block_end_indices parameter is set, this parameter is used as starting index of each global window. - global_block_end_indices: optional: a list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size of global_block_indices parameter, and combining this two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered as global attention. - attention: optional: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure. + block: optional: an integer determining the block size. Current implementation of sparse + self-attention is based on blocked sparse matrices. In which this parameter defines size + of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a + different sparsity layout; default is false and this will be satisfied based on + availability. + num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding + local attention window. + global_block_indices: optional: a list of integers determining which blocks are considered + as global attention. Given indices, determine the blocks that all other token blocks + attend to and they attend to all other token blocks. Default value is only index 0. + Notice that if global_block_end_indices parameter is set, this parameter is used as + starting index of each global window. + global_block_end_indices: optional: a list of integers determining end indices of global + window blocks. By default this is not used. But if it is set, it must have the same size + of global_block_indices parameter, and combining this two parameters, for each index i, + blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are + considered as global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. 
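Putting the arguments above together, a minimal usage sketch of this config (illustrative values, not part of the patch; at this point in the series the block size argument is still named block and defaults to 16):

    from xformers.components.attention.sparsity_config import BSLongformerSparsityConfig

    config = BSLongformerSparsityConfig(
        num_heads=1,
        num_sliding_window_blocks=3,      # 3-block sliding window along the diagonal
        global_block_indices=[0],
        global_block_end_indices=[1],     # block range [0, 1) is global: dense first row/column
    )
    layout = config.make_layout(seq_len=128)   # (1, 8, 8) block mask with 16-token blocks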
""" super().__init__(num_heads, block, different_layout_per_head) @@ -550,15 +713,19 @@ def __init__(self, self.global_block_indices = global_block_indices self.attention = attention - if (global_block_end_indices is not None): - if (len(global_block_indices) != len(global_block_end_indices)): + if global_block_end_indices is not None: + if len(global_block_indices) != len(global_block_end_indices): raise ValueError( - f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!' + f"""Global block start indices length, {len(global_block_indices)}, must be same as + global block end indices length, {len(global_block_end_indices)}!""" ) - for _, (start_idx, end_idx) in enumerate(zip(global_block_indices, global_block_end_indices)): + for _, (start_idx, end_idx) in enumerate( + zip(global_block_indices, global_block_end_indices) + ): if start_idx >= end_idx: raise ValueError( - f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!' + f"""Global block start index, {start_idx}, must be smaller than global block end + index, {end_idx}!""" ) self.global_block_end_indices = global_block_end_indices @@ -566,15 +733,18 @@ def set_sliding_window_layout(self, h, layout): """Sets sliding local attention layout used by the given head in the sparse attention. Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local sliding window layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout + of all head in which local sliding window layout is set """ num_blocks = layout.shape[1] - if (num_blocks < self.num_sliding_window_blocks): + if num_blocks < self.num_sliding_window_blocks: raise ValueError( - f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' + f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller + than overall number of blocks in a row, {num_blocks}!""" ) w = self.num_sliding_window_blocks // 2 @@ -588,32 +758,36 @@ def set_global_layout(self, h, layout): """Sets global attention layout used by the given head in the sparse attention. 
Arguments: h: required: an integer determining head index - layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completely set at this step + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which global layout is set """ num_blocks = layout.shape[1] - if (self.global_block_end_indices is None): + if self.global_block_end_indices is None: for idx in self.global_block_indices: # if global block idx is in the range of the sequence blocks - if (idx < num_blocks): - #global rows + if idx < num_blocks: + # global rows layout[h, idx, :] = 1 - #global columns + # global columns layout[h, :, idx] = 1 else: - for _, (start_idx, end_idx) in enumerate(zip(self.global_block_indices, self.global_block_end_indices)): + for _, (start_idx, end_idx) in enumerate( + zip(self.global_block_indices, self.global_block_end_indices) + ): # if global block idx is in the range of the sequence blocks - if (start_idx < num_blocks): + if start_idx < num_blocks: end_idx = min(end_idx, num_blocks) - #global rows + # global rows layout[h, start_idx:end_idx, :] = 1 - #global columns + # global columns layout[h, :, start_idx:end_idx] = 1 - if self.attention == 'unidirectional': + if self.attention == "unidirectional": layout = torch.tril(layout) return layout @@ -622,7 +796,8 @@ def make_layout(self, seq_len): Arguments: seq_len: required: an integer determining number of attention heads of the layer. 
Return: - layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer` sparsity layout of all head + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer` + sparsity layout of all head """ layout = self.setup_layout(seq_len) From 7a9c9a3f70b563452e48ce8aa24dfff4ffd32b2e Mon Sep 17 00:00:00 2001 From: igormolybogFB <106693515+igormolybogFB@users.noreply.github.com> Date: Wed, 1 Jun 2022 16:07:39 -0700 Subject: [PATCH 3/8] Added copyright, fixed block Added copyright fixed block -> block_size removed DeepSpeed TODO --- .../attention/attention_patterns.py | 8 ++-- .../components/attention/sparsity_config.py | 41 ++++++++++--------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/xformers/components/attention/attention_patterns.py b/xformers/components/attention/attention_patterns.py index 24dc38482a..9c817debb9 100644 --- a/xformers/components/attention/attention_patterns.py +++ b/xformers/components/attention/attention_patterns.py @@ -268,22 +268,22 @@ def get_slopes_power_of_2(n: int) -> List[float]: def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int): - config = FixedSparsityConfig(num_heads=num_heads, block=block_size) + config = FixedSparsityConfig(num_heads=num_heads, block_size=block_size) return config.make_layout(seq_len) def quick_variable_layout(num_heads: int, block_size: int, seq_len: int): - config = VariableSparsityConfig(num_heads=num_heads, block=block_size) + config = VariableSparsityConfig(num_heads=num_heads, block_size=block_size) return config.make_layout(seq_len) def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int): - config = BigBirdSparsityConfig(num_heads=num_heads, block=block_size) + config = BigBirdSparsityConfig(num_heads=num_heads, block_size=block_size) return config.make_layout(seq_len) def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int): - config = BSLongformerSparsityConfig(num_heads=num_heads, block=block_size) + config = BSLongformerSparsityConfig(num_heads=num_heads, block_size=block_size) return config.make_layout(seq_len) diff --git a/xformers/components/attention/sparsity_config.py b/xformers/components/attention/sparsity_config.py index f0081fd8cd..57df975499 100644 --- a/xformers/components/attention/sparsity_config.py +++ b/xformers/components/attention/sparsity_config.py @@ -1,3 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. """ The code has been adopted from DeepSpeed (https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py) @@ -14,12 +18,11 @@ class SparsityConfig: needs to extend it based on required property and functionality. """ - def __init__(self, num_heads, block=16, different_layout_per_head=False): + def __init__(self, num_heads, block_size=16, different_layout_per_head=False): """Initialize the Sparsity Pattern Config. - For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of + block_size: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. 
different_layout_per_head: optional: a boolean determining if each head should be @@ -28,7 +31,7 @@ def __init__(self, num_heads, block=16, different_layout_per_head=False): """ self.num_heads = num_heads - self.block = block + self.block_size = block_size self.different_layout_per_head = different_layout_per_head self.num_layout_heads = num_heads if different_layout_per_head else 1 @@ -41,11 +44,11 @@ def setup_layout(self, seq_len): of all head; initialized with zero """ - if seq_len % self.block != 0: + if seq_len % self.block_size != 0: raise ValueError( - f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block}!" + f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block_size}!" ) - num_blocks = seq_len // self.block + num_blocks = seq_len // self.block_size # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout. layout = torch.zeros( (self.num_heads, num_blocks, num_blocks), dtype=torch.int64 @@ -111,7 +114,7 @@ class FixedSparsityConfig(SparsityConfig): def __init__( self, num_heads, - block=16, + block_size=16, different_layout_per_head=False, num_local_blocks=4, num_global_blocks=1, @@ -123,7 +126,7 @@ def __init__( For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of + block_size: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. different_layout_per_head: optional: a boolean determining if each head should be @@ -152,7 +155,7 @@ def __init__( Of course, there is a limitation based on num_local_blocks and num_global_blocks. """ - super().__init__(num_heads, block, different_layout_per_head) + super().__init__(num_heads, block_size, different_layout_per_head) self.num_local_blocks = num_local_blocks @@ -297,7 +300,7 @@ class VariableSparsityConfig(SparsityConfig): def __init__( self, num_heads, - block=16, + block_size=16, different_layout_per_head=False, num_random_blocks=0, local_window_blocks=[4], @@ -310,7 +313,7 @@ def __init__( For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of sparse + block_size: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. different_layout_per_head: optional: a boolean determining if each head should be assigned a @@ -346,7 +349,7 @@ def __init__( attention not only includes the vertical blocks, but also horizontal blocks. 
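After this rename the configs are constructed with block_size rather than block; a minimal usage sketch with illustrative values (not part of the patch):

    from xformers.components.attention.sparsity_config import VariableSparsityConfig

    config = VariableSparsityConfig(
        num_heads=2,
        block_size=16,             # renamed from `block` in this commit
        local_window_blocks=[4],
        global_block_indices=[0],
        attention="bidirectional",
    )
    layout = config.make_layout(seq_len=128)   # (2, 8, 8) block mask

The quick_ helpers updated earlier in this commit wrap these configs with their default parameters, and layout_to_pattern expands the block layout into a token-level mask, so the equivalent shortcut would look like:

    import xformers.components.attention.attention_patterns as AP

    layout = AP.quick_variable_layout(num_heads=2, block_size=16, seq_len=128)  # (2, 8, 8)
    pattern = AP.layout_to_pattern(layout=layout, block_size=16)                # (2, 128, 128)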
""" - super().__init__(num_heads, block, different_layout_per_head) + super().__init__(num_heads, block_size, different_layout_per_head) self.num_random_blocks = num_random_blocks self.local_window_blocks = local_window_blocks @@ -510,7 +513,7 @@ class BigBirdSparsityConfig(SparsityConfig): def __init__( self, num_heads, - block=16, + block_size=16, different_layout_per_head=False, num_random_blocks=1, num_sliding_window_blocks=3, @@ -521,7 +524,7 @@ def __init__( For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of + block_size: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. different_layout_per_head: optional: a boolean determining if each head should be assigned @@ -542,7 +545,7 @@ def __init__( matrix is mirror of the lower triangular in the above figure. """ - super().__init__(num_heads, block, different_layout_per_head) + super().__init__(num_heads, block_size, different_layout_per_head) self.num_random_blocks = num_random_blocks self.num_sliding_window_blocks = num_sliding_window_blocks @@ -670,7 +673,7 @@ class BSLongformerSparsityConfig(SparsityConfig): def __init__( self, num_heads, - block=16, + block_size=16, different_layout_per_head=False, num_sliding_window_blocks=3, global_block_indices=[0], @@ -681,7 +684,7 @@ def __init__( For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial Arguments: num_heads: required: an integer determining number of attention heads of the layer. - block: optional: an integer determining the block size. Current implementation of sparse + block_size: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. different_layout_per_head: optional: a boolean determining if each head should be assigned a @@ -707,7 +710,7 @@ def __init__( matrix is mirror of the lower triangular in the above figure. 
""" - super().__init__(num_heads, block, different_layout_per_head) + super().__init__(num_heads, block_size, different_layout_per_head) self.num_sliding_window_blocks = num_sliding_window_blocks self.global_block_indices = global_block_indices From 1c54759b8440bbb4dea447041837b4eb21da817e Mon Sep 17 00:00:00 2001 From: igormolybogFB <106693515+igormolybogFB@users.noreply.github.com> Date: Wed, 1 Jun 2022 18:41:55 -0700 Subject: [PATCH 4/8] Add to changelog, extra tests --- CHANGELOG.md | 1 + tests/test_attention_patterns.py | 117 +++++++++++++++++++++++++++++++ xformers/components/.DS_Store | Bin 6148 -> 8196 bytes 3 files changed, 118 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 478c7e19e0..489d5d495e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed dupliacated biases in the FusedMLP layers [#317] ### Added +- Four blocksparsity layouts from DeepSpeed [#320] ## [0.0.11] - 2022-05-30 ### Fixed diff --git a/tests/test_attention_patterns.py b/tests/test_attention_patterns.py index 68d498c023..28d1514b3d 100644 --- a/tests/test_attention_patterns.py +++ b/tests/test_attention_patterns.py @@ -9,6 +9,12 @@ import torch import xformers.components.attention.attention_patterns as AP +from xformers.components.attention.sparsity_config import ( + BigBirdSparsityConfig, + BSLongformerSparsityConfig, + DenseSparsityConfig, + FixedSparsityConfig, +) # baseline implementations @@ -305,6 +311,11 @@ def test_quick_layouts(): ).long(), ) + # BigBird (just the shape) + assert AP.quick_variable_layout( + num_heads, block_size, seq_size + ).shape == torch.Size([num_heads, seq_size // block_size, seq_size // block_size]) + def test_layout_to_pattern(): torch.allclose( @@ -328,3 +339,109 @@ def test_layout_to_pattern(): ] ), ) + + +def test_dense_sparsity_config(): + sc = DenseSparsityConfig(num_heads=1, block_size=16) + with pytest.raises(expected_exception=ValueError): + sc.setup_layout(seq_len=17) + assert torch.allclose( + sc.make_layout(seq_len=32), torch.Tensor([[[1, 1], [1, 1]]]).long() + ) + + +def test_big_bird_sparsity_config(): + sc = BigBirdSparsityConfig( + num_heads=1, + block_size=16, + num_random_blocks=2, + num_sliding_window_blocks=1, + num_global_blocks=1, + ) + with pytest.raises(expected_exception=ValueError): + sc.make_layout(seq_len=16) + sc = BigBirdSparsityConfig( + num_heads=1, + block_size=16, + num_random_blocks=1, + num_sliding_window_blocks=2, + num_global_blocks=1, + ) + with pytest.raises(expected_exception=ValueError): + sc.make_layout(seq_len=16) + sc = BigBirdSparsityConfig( + num_heads=1, + block_size=16, + num_random_blocks=1, + num_sliding_window_blocks=1, + num_global_blocks=2, + ) + with pytest.raises(expected_exception=ValueError): + sc.make_layout(seq_len=16) + + +def test_bslongformer_sparsity_config(): + sc = BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[1]) + assert torch.allclose( + sc.make_layout(128), + torch.Tensor( + [ + [ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 0], + [1, 0, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1], + ] + ] + ).long(), + ) + with pytest.raises(expected_exception=ValueError): + BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[]) + with pytest.raises(expected_exception=ValueError): + BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[-1]) + + +def 
test_fixed_sparsity_config(): + # check that the case end < num_blocks is correct + sc = FixedSparsityConfig(num_heads=1, horizontal_global_attention=True) + assert torch.allclose( + sc.make_layout(112), + torch.Tensor( + [ + [ + [1, 1, 1, 1, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 1], + [1, 1, 1, 1, 0, 0, 1], + [1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + ] + ] + ).long(), + ) + with pytest.raises(expected_exception=ValueError): + FixedSparsityConfig(num_heads=1, num_local_blocks=3, num_global_blocks=2) + with pytest.raises(expected_exception=ValueError): + FixedSparsityConfig(num_heads=1, attention="directional") + with pytest.raises(expected_exception=ValueError): + FixedSparsityConfig( + num_heads=1, attention="unidirectional", horizontal_global_attention=True + ) + with pytest.raises(expected_exception=ValueError): + FixedSparsityConfig( + num_heads=1, + num_different_global_patterns=2, + different_layout_per_head=False, + ) + with pytest.raises(expected_exception=ValueError): + FixedSparsityConfig( + num_heads=1, + num_different_global_patterns=10, + num_local_blocks=4, + num_global_blocks=1, + ) diff --git a/xformers/components/.DS_Store b/xformers/components/.DS_Store index 83ebe1c02127edc9322d3639a80252871d66837b..836cdae161f63bf814d6ad9ecfbb1fc969741272 100644 GIT binary patch delta 602 zcmZWnL2DC16n<|vH5(I?-8PA_BB;3t1&QEQ8Y7}XB(V)8!P>C9nQVt{XV%$GYkNr! z{R3gen;>}dX2`J@k6!!<;!n_<3Lc!zZb9n6@O|^W_r5pGyp!5Vy#fH(7!B?KL{yVI zLOnZuS)Sg$f%#ecRGnYf{;2QMnl_!JFh9?wUgMbB2cG3}OO}7JJeT@Gt@cgVQBhM% z7t&hhVm5avU(k)^6{BdBO6ALy(v{V#t9^aNGVMJ_@DUq4;4T*pcSN2$f(|x4)27Vn z2IN(+NsmciUp|+*K?XaCG?wL^BTpt3&m}z)gs_ESYdG{t(EW$^`g*~)1$W&x_h~>m z>uvkH{ID62{U#f7+46vr!70H6~H8^J&4g;+d2nb{JDU6%@h0~c9 rvvY6=G6R(WfdDs(44dP5<}d>Qi$f4s From b942065a48a9e107000305756fe0aba400fcaa8c Mon Sep 17 00:00:00 2001 From: igormolybogFB <106693515+igormolybogFB@users.noreply.github.com> Date: Wed, 1 Jun 2022 18:59:23 -0700 Subject: [PATCH 5/8] Fix block_size for Dense SC --- xformers/components/attention/sparsity_config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xformers/components/attention/sparsity_config.py b/xformers/components/attention/sparsity_config.py index 57df975499..b69d6ce8c0 100644 --- a/xformers/components/attention/sparsity_config.py +++ b/xformers/components/attention/sparsity_config.py @@ -76,18 +76,20 @@ class DenseSparsityConfig(SparsityConfig): comprehension. """ - def __init__(self, num_heads, block=16, different_layout_per_head=False): + def __init__(self, num_heads, block_size=16, different_layout_per_head=False): """Initialize the Dense Sparsity Pattern Config. In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension. Arguments: num_heads: required: an integer determining number of attention heads of the layer. - seq_len: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: this is just for the sake of consistency with other sparsity formats; can ignore it for DenseSparsityConfig """ - super().__init__(num_heads, block, different_layout_per_head) + super().__init__(num_heads, block_size, different_layout_per_head) def make_layout(self, seq_len): """Set 1 to all blocks of the layout meanins the pattern is dense; not sparse. From c5e7415a61a6c04279e9b7813338fd4bf6a703f5 Mon Sep 17 00:00:00 2001 From: igormolybogFB <106693515+igormolybogFB@users.noreply.github.com> Date: Wed, 1 Jun 2022 19:03:27 -0700 Subject: [PATCH 6/8] Fixed a unittest --- tests/test_attention_patterns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_attention_patterns.py b/tests/test_attention_patterns.py index 28d1514b3d..8fcc4ed8cf 100644 --- a/tests/test_attention_patterns.py +++ b/tests/test_attention_patterns.py @@ -426,7 +426,7 @@ def test_fixed_sparsity_config(): ) with pytest.raises(expected_exception=ValueError): FixedSparsityConfig(num_heads=1, num_local_blocks=3, num_global_blocks=2) - with pytest.raises(expected_exception=ValueError): + with pytest.raises(expected_exception=NotImplementedError): FixedSparsityConfig(num_heads=1, attention="directional") with pytest.raises(expected_exception=ValueError): FixedSparsityConfig( From 2aed52c7e105087b31ab4685f0d2c1c6b00dded4 Mon Sep 17 00:00:00 2001 From: igormolybogFB <106693515+igormolybogFB@users.noreply.github.com> Date: Wed, 1 Jun 2022 23:17:15 -0700 Subject: [PATCH 7/8] More tests for Variable, BigBird --- tests/test_attention_patterns.py | 34 +++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/tests/test_attention_patterns.py b/tests/test_attention_patterns.py index 8fcc4ed8cf..73756b92b2 100644 --- a/tests/test_attention_patterns.py +++ b/tests/test_attention_patterns.py @@ -14,6 +14,7 @@ BSLongformerSparsityConfig, DenseSparsityConfig, FixedSparsityConfig, + VariableSparsityConfig, ) @@ -312,9 +313,9 @@ def test_quick_layouts(): ) # BigBird (just the shape) - assert AP.quick_variable_layout( - num_heads, block_size, seq_size - ).shape == torch.Size([num_heads, seq_size // block_size, seq_size // block_size]) + assert AP.quick_bigbird_layout(num_heads, block_size, seq_size).shape == torch.Size( + [num_heads, seq_size // block_size, seq_size // block_size] + ) def test_layout_to_pattern(): @@ -378,6 +379,8 @@ def test_big_bird_sparsity_config(): ) with pytest.raises(expected_exception=ValueError): sc.make_layout(seq_len=16) + with pytest.raises(expected_exception=NotImplementedError): + BigBirdSparsityConfig(num_heads=1, attention="directional") def test_bslongformer_sparsity_config(): @@ -445,3 +448,28 @@ def test_fixed_sparsity_config(): num_local_blocks=4, num_global_blocks=1, ) + + +def test_variable_sparsity_config(): + sc = VariableSparsityConfig(num_heads=1, global_block_end_indices=[1]) + assert torch.allclose( + sc.make_layout(128), + torch.Tensor( + [ + [ + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + [1, 0, 0, 0, 1, 1, 1, 1], + ] + ] + ).long(), + ) + with pytest.raises(expected_exception=ValueError): + VariableSparsityConfig(num_heads=1, global_block_end_indices=[]) + with pytest.raises(expected_exception=ValueError): + VariableSparsityConfig(num_heads=1, global_block_end_indices=[-1]) From 878ea7273831aa78ed8b21305d938cf9cef59fa0 Mon Sep 17 00:00:00 2001 From: 
igormolybogFB <106693515+igormolybogFB@users.noreply.github.com> Date: Thu, 2 Jun 2022 12:14:07 -0700 Subject: [PATCH 8/8] removed .DS_Store --- .DS_Store | Bin 8196 -> 0 bytes xformers/.DS_Store | Bin 6148 -> 0 bytes xformers/components/.DS_Store | Bin 8196 -> 0 bytes xformers/components/attention/.DS_Store | Bin 8196 -> 0 bytes 4 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 xformers/.DS_Store delete mode 100644 xformers/components/.DS_Store delete mode 100644 xformers/components/attention/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index e4c136bf463067aefda298c5ee52d00322782785..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHM&2AGh5FWQF-6;G-sFf-Z(n?%Q3A7*~E=l?W5wwxIEj<7Vo1eB(x4R*`X{f5A zoZ%gK1+F{^@4^Yb@orG;Y>B@I=pPlS;HrFUT3wHVuXW!+>GHFkl!k4EzfW;Lhe`O}O{tUQ-$d3EGfmR;0#(rk!3nvEPa zD4-^_q&xxU6f^eJ8H>*v>P9ls4!%-Uu+w!bcS_z}|Lf zU0sFbYE(NXzvgTKu>pi5{n931Lor9As4Ujuf}+ z%wJ$9!rVg4CUr*4zhWsr(Na1do>1^vjMc(uJ$PgZ9@$Uxgo3aXM2Ai_&e#thNBl4& zSJMOM`VMDyRn6hrHBHIi1kHLs>?Ehpw|d8Ghf$&MT}(}9X0FUyvsTV}>l_AsCvu}+ zSa!Rw`DrHzy!x>2)?c+-{aSu*HSi<1?YFuT&~CP%@@B8?H-mmT==sf1Vx73d%39f4 ze&OtFrLeMSKPoJrFWP5M9u^kumF36h=UHq1?)|6RwIlD?4?bb{BZ=f^3C>bA@%b|Y zAH;qSPn@pr1wNk_jL4zqQ10O-&Om-|x@gf@m-}LrAq%^CALme4?)$FJJ)c?%_hb5( zgx9fkY3!<=Fi$!I({aNB$N%e-zyIqn^K2L}3|t}xL}txdD`P>!OH5xI$J#dXJ7i9* zH;AQ#ppfA>qzuO)fB#{Kx(!vv6cag!C0daF`iB7gU@-UprTATC?tl3P{aNt~MSV|L diff --git a/xformers/.DS_Store b/xformers/.DS_Store deleted file mode 100644 index 08c08046dd2c137e61164c18b199defc47431d33..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKL66cv82yGtD4+=kP4;5a#4B0RMHAx%+4bO6j2_e=h3Yoj?KHGVHiU$;{s;ep zSAU8B#go48o7tokNIV&%^O2c1oq2ECzM0N2M5G4u={=%05qU_A)h4nX!R?&3qGD^V z1C{gQr$N4-PjAyU;>V?*Ea8(`$FQafVYTrE3Nj{E}a3+Fi8iMlXRg_HiVxZ^A zG!yO6Gbb)AFZNsmFO_l5C@hUL&y6pIxlzZ{hT!MBdM2I~MRpCmnan-Rip`#7 zHWaxY&B2cFhC82H{3!BqD;Iklurt$Z?P+(_&H?AZ|LXv+4?YrO-{RV!EFH+?2mq|2 zS^{PMv%nhPV&CH0AX*U0NP$MG>=8p5IqH4O>swqKG;&h*@S*ITl|7*-xjXvzg-)t( z(B;kn=fI`|2j;TP`~TDZ=l@NSyK)XV2mUDsRP}UtI>0O0y>;v5c(3)5-XU=?Z*5Q{ k$n1731>TAektEROa{<`5xHgCyg!>~PHn_|=@JAi^1;ZxNYybcN diff --git a/xformers/components/.DS_Store b/xformers/components/.DS_Store deleted file mode 100644 index 836cdae161f63bf814d6ad9ecfbb1fc969741272..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHML2nyH6n>KiVq=nQQl||?NGp6nYE%pjNC=@Cw+JGH65=Q%YRj^_o;XX^GuH0f zCXFKb3@3g7+&OVY;=+*&S8j-(0N0-2o0)aHaU5}}P^Hd9GwUFsSOzQuPXhyZ zXRA`W;Jt4~ZEG2@3_OzzsPn-^RkD%Gk&*i7K%-Lt$XRsj0v&mP#<;m`WTb%;GjL-0%q-7Pn4TRtSDO=S7-?I}fMwu311fi~=$t(w zaVCF1-iT!w#4><|3aa-}Z~UVRG@=FaDMY>>(Eyr~5*|{W`%hwjoWPbfA|nn&jyyXX zd*#0o!7%OR>-#C_#Zj8o>%X&^OS8{CKUbP7RZ90;hjQ4;{H&ig{p7B?+L1C2#=0Nu zMttZw)!R~Jek6FJ0a2U7rGL}g*?(@5yBx<*Mei(_) zUdF%5TG0dUIm;8K+r0hxCi$W>P z$``d;Ty7;?xvbs#+Ynl7d^}cOGm1ksqY%{^JHKeXR2_HalI;j7c)8ZtG96{Simu~v z|Acnwh#t_l^d0>`Khv-D2mM74*&K7&tL!>kXYa9F>?8Im+h!p2ZJvzm3tResf0=WlUNnb(>bwFl7 zs}YWV^;vErG}*%mKUb(reTsEhdUOb!kCk|+5rVs?12;tf6kjuH2gNll7wkQF>QbQh zrF;qMO!J|SnePk7Ocg#PW>%jT%@N~cbyK{Z<%iUfQ}k28{<-$k#{QI#DLc@|P7yV) zQ3D_ojSYa~Q<3U^6p>f8eHW{bvF;G{f^|>zrtV778!m0J3_Mu|E*W>L703@R{{H`D zC$-)z1D1iOiUBrzr*)@^C5?Z#Os-Dl+AiuBsH!Ntk&zmLMyKPDIvt06^bbSmE=(1s Yk;{>h=t24aKLptKzb4)`&ol%70vFk{H2?qr diff --git a/xformers/components/attention/.DS_Store b/xformers/components/attention/.DS_Store deleted file mode 100644 index 42fec3046492142695ed15a682d18f74110e1e99..0000000000000000000000000000000000000000 GIT binary 
patch literal 0 HcmV?d00001 literal 8196 zcmeI1&u-H|5XNUqq?)!0rBX#NNPdD!yg{hSfin+~<{!aos3Rvq@g)r0IR_d+i|1yh1wgB89AF z$ls+b_q5YQVjwY)7)T5x1`-4R0|R`swPhXOeSbHNiGjqxlVm{O4>?`sAjqAU_Ud4; z5&+p_+8Wo?0fxzg90a-Z(xBM%)q@LA7fKB0;oQ$v9XSYc=cR{>^KfxtWfv+GXRCwD zOO^M{779^kmH z-9vjtPA1KK8k2RsOg<9D7ksK>UO45eu7865%y`~_n1#G8Wz||onXBXa;7WE)MdUL^ z&37SRV{007b?m9kjdjh`0^}M?w^-@Ou`=iSHRfJ`8nND3rpTfF+L{+J-r21U^a=X< zl&Jp|R@D0lPMK}sGH|+6I`XT`HuR5JV@uA~*w=j$TLV@bn9O`y%O_++Uzu#^AKC6J zLqzi!#@PE-WA8csjPPyR<&`L7b;LPQ__$}FhKAG=!lgwvWfx3{GUBgsShsYBrN2*^ zo-%H+qEC@#qhMFWY$!){8=1YU3S#bjBQ*5cxTz%u9xVgUd}J3||G(=#|9`acr@#^e zi2**6z3UIxS1=K3>pf1i*3RiSbZx4)^U@#;R*oaB97o>&!;tG7RmBW~+7+o`$E*z#nE%bgKXW
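Reviewer note, not part of the patches: a minimal usage sketch of the quick_* helpers exercised by test_quick_layouts, assuming they land with the (num_heads, block_size, seq_size) signature shown in the tests; the variable names and the shape check below are illustrative only.

import torch

import xformers.components.attention.attention_patterns as AP

num_heads, block_size, seq_size = 2, 16, 128
num_blocks = seq_size // block_size  # 8 block rows/columns per head here

# Each helper returns one block-level mask per head: 1 keeps a Block x Block
# tile of the attention matrix, 0 drops it.
fixed = AP.quick_fixed_layout(num_heads, block_size, seq_size)
bslongformer = AP.quick_bslongformer_layout(num_heads, block_size, seq_size)
bigbird = AP.quick_bigbird_layout(num_heads, block_size, seq_size)

for layout in (fixed, bslongformer, bigbird):
    assert layout.shape == torch.Size([num_heads, num_blocks, num_blocks])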
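Reviewer note, not part of the patches: the sparsity configs can also be driven directly; the sketch below only restates what test_dense_sparsity_config and test_bslongformer_sparsity_config assert (default block_size=16, layouts built at block granularity), everything else is illustrative.

import torch

from xformers.components.attention.sparsity_config import (
    BSLongformerSparsityConfig,
    DenseSparsityConfig,
)

# Dense keeps every block; a seq_len that is not a multiple of block_size is
# rejected with a ValueError (see the setup_layout(seq_len=17) test above).
dense = DenseSparsityConfig(num_heads=1, block_size=16)
assert torch.allclose(
    dense.make_layout(seq_len=32), torch.Tensor([[[1, 1], [1, 1]]]).long()
)

# BSLongformer combines a sliding window of blocks with global blocks; with
# global_block_end_indices=[1] the first block row and column stay dense.
bslongformer = BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[1])
layout = bslongformer.make_layout(128)  # shape [1, 8, 8] with block_size=16
assert layout[0, 0].tolist() == [1] * 8
assert layout[0, :, 0].tolist() == [1] * 8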
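Reviewer note, not part of the patches: a short sketch of the validation behaviour the new tests pin down; the raised exception types are taken from the tests (including the ValueError -> NotImplementedError change in PATCH 6/8), while the comments are a best-effort reading of why each case fails.

import pytest

from xformers.components.attention.sparsity_config import (
    BigBirdSparsityConfig,
    FixedSparsityConfig,
)

# seq_len=16 with block_size=16 leaves a single block, so two random blocks
# per row cannot be placed and make_layout raises.
sc = BigBirdSparsityConfig(
    num_heads=1,
    block_size=16,
    num_random_blocks=2,
    num_sliding_window_blocks=1,
    num_global_blocks=1,
)
with pytest.raises(ValueError):
    sc.make_layout(seq_len=16)

# Unsupported attention modes are rejected at construction time.
with pytest.raises(NotImplementedError):
    FixedSparsityConfig(num_heads=1, attention="directional")

# Mismatched local/global block counts are rejected as well; this exact
# combination is copied from test_fixed_sparsity_config.
with pytest.raises(ValueError):
    FixedSparsityConfig(num_heads=1, num_local_blocks=3, num_global_blocks=2)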