Skip to content

Commit

Permalink
Fix two warnings (#676)
Browse files Browse the repository at this point in the history
- torch.nn.functional.sigmoid is deprecated in favor of torch.sigmoid.
- Clip cosh input in sechsq to avoid overflow.
  • Loading branch information
danieldk authored May 23, 2022
1 parent 07a7dcf commit ab16559
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions thinc/backends/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,8 @@ def erf(self, X: FloatsType) -> FloatsType:
return out

def sechsq(self, X: FloatsType) -> FloatsType:
# Avoid overflow in cosh. Clipping at |20| has an error of 1.7e-17.
X = self.xp.clip(X, -20.0, 20.0)
return (1 / self.xp.cosh(X)) ** 2

def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType:
Expand Down
2 changes: 1 addition & 1 deletion thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def torch_hard_swish_mobilenet(x):
return torch.nn.functional.hardswish(x)

def torch_sigmoid(x):
return torch.nn.functional.sigmoid(x)
return torch.sigmoid(x)

# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py#L37
def torch_gelu_approx(x):
Expand Down

0 comments on commit ab16559

Please sign in to comment.