Check that Torch-verified activations obey inplace (#709)
And fix some activations that do not obey the `inplace` kwarg.
danieldk authored Jun 30, 2022
1 parent 2ef3f3a commit c7b0d67
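
For context, the `inplace` contract that the new test assertions enforce is: with `inplace=False` an op must return a new array, and with `inplace=True` it must write into and return the array it was given. A minimal sketch of that contract (illustration only, not part of this commit; assumes thinc with `NumpyOps` is installed):

# Illustration of the `inplace` contract checked by the new tests.
import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
X = numpy.asarray([0.5, -1.0, 2.0], dtype="float32")

Y = ops.sigmoid(X)  # inplace=False (default): a new array is returned
assert Y is not X

Y_inplace = ops.sigmoid(X, inplace=True)  # inplace=True: X is overwritten and returned
assert Y_inplace is X
assert numpy.allclose(Y, Y_inplace, atol=1e-6)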
Showing 2 changed files with 14 additions and 9 deletions.
10 changes: 5 additions & 5 deletions thinc/backends/ops.py
@@ -722,15 +722,15 @@ def as_contig(self, data: ArrayT, dtype: Optional[DTypes] = None) -> ArrayT:
         return self.xp.ascontiguousarray(data, **kwargs)

     def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType:
-        # To prevent overflows and help with regularization/numerical stability
-        X = self.xp.clip(X, -20.0, 20.0)
-
         if inplace:
+            # To prevent overflows and help with regularization/numerical stability
+            X = self.xp.clip(X, -20.0, 20.0, out=X)
             self.xp.exp(-X, out=X)
             X += 1.0  # type: ignore[assignment]
             X **= -1.0  # type: ignore[assignment]
             return cast(FloatsType, X)
         else:
+            X = self.xp.clip(X, -20.0, 20.0)
             return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X)))

     def backprop_sigmoid(
@@ -909,15 +909,15 @@ def backprop_relu_k(
         return self.backprop_clipped_linear(dY, X, max_val=n, inplace=inplace)

     def hard_sigmoid(self, X: FloatsType, inplace: bool = False) -> FloatsType:
-        return self.clipped_linear(X, slope=0.2, offset=0.5)
+        return self.clipped_linear(X, slope=0.2, offset=0.5, inplace=inplace)

     def backprop_hard_sigmoid(
         self, dY: FloatsType, X: FloatsType, inplace: bool = False
     ) -> FloatsType:
         return self.backprop_clipped_linear(dY, X, slope=0.2, offset=0.5)

     def hard_tanh(self, X: FloatsType, inplace: bool = False) -> FloatsType:
-        return self.clipped_linear(X, min_val=-1.0, max_val=1.0)
+        return self.clipped_linear(X, min_val=-1.0, max_val=1.0, inplace=inplace)

     def backprop_hard_tanh(
         self, dY: FloatsType, X: FloatsType, inplace: bool = False
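The root cause of the sigmoid change above: `xp.clip` without an `out=` argument allocates a new array, so the old code silently dropped the caller's buffer even when `inplace=True`. A small numpy-only sketch of the difference (illustration, not part of the diff):

# Illustration only: clip copies unless an output buffer is supplied.
import numpy

X = numpy.asarray([-30.0, 0.0, 30.0], dtype="float32")

clipped_copy = numpy.clip(X, -20.0, 20.0)             # allocates a new array
assert clipped_copy is not X

clipped_inplace = numpy.clip(X, -20.0, 20.0, out=X)   # writes into X and returns it
assert clipped_inplace is X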
13 changes: 9 additions & 4 deletions thinc/tests/backends/test_ops.py
@@ -1302,7 +1302,10 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
     y_thinc = forward(x_thinc)
     y.backward()
     assert x_thinc.dtype == y_thinc.dtype
-    assert ops.xp.isclose(y_thinc, forward(x_thinc, inplace=True), atol=1e-06)
+    assert y_thinc is not x_thinc
+    y_thinc_inplace = forward(x_thinc, inplace=True)
+    assert y_thinc_inplace is x_thinc
+    assert ops.xp.isclose(y_thinc, y_thinc_inplace, atol=1e-06)
     assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-06)
     x_thinc = ops.asarray([x], dtype=dtype)
     dY_thinc = ops.asarray([1.0], dtype=dtype)
@@ -1314,10 +1317,12 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
     if params == {"dY", "X", "Y"}:
         dx_thinc = backward(dY_thinc, Y=y_thinc, X=x_thinc)
         assert dx_thinc.dtype == x_thinc.dtype
-        assert ops.xp.isclose(
-            dx_thinc,
-            backward(dY=dY_thinc_inplace, Y=y_thinc, X=x_thinc, inplace=True),
+        assert dx_thinc is not dY_thinc
+        dx_thinc_inplace = backward(
+            dY=dY_thinc_inplace, Y=y_thinc, X=x_thinc, inplace=True
         )
+        assert dx_thinc_inplace is dY_thinc_inplace
+        assert ops.xp.isclose(dx_thinc, dx_thinc_inplace)
         assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06)
     elif params == {"Y", "dY"}:
         dx_thinc = backward(dY_thinc, Y=y_thinc)
