Skip to content

Commit

Permalink
Fixed sparse attention docs warning
Browse files Browse the repository at this point in the history
  • Loading branch information
LongxingTan committed Jun 13, 2023
1 parent 2f2f2a5 commit 2662e25
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tfts/layers/attention_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,16 @@ def __init__(self, hidden_size: int, num_heads: int, attention_dropout: float =
def build(self, input_shape):
super().build(input_shape)

def call(self, x):
def call(self, x, mask=None):
"""Sparse attention
Parameters
----------
x : tf.Tensor
_description_
mask : tf.Tensor, optional
_description_, by default None
"""
return

def get_config(self):
Expand All @@ -254,14 +263,14 @@ def __init__(self, **kwargs):
def build(self, input_shape):
super().build(input_shape)

def call(self, x, x_mask=None):
"""_summary_
def call(self, x, mask=None):
"""Fast attention
Parameters
----------
x : _type_
x : tf.Tensor
_description_
x_mask : _type_, optional
mask : _type_, optional
_description_, by default None
"""
return
Expand Down

0 comments on commit 2662e25

Please sign in to comment.