Merge pull request #361 from rsokl/view-no-clear-grad
view operations no longer null grads
rsokl authored Mar 3, 2021
2 parents a60481b + 2e4dd6d commit 419bec5
Showing 4 changed files with 147 additions and 82 deletions.
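
As a rough illustration of the behavior change (a minimal sketch distilled from the updated tests below; it is not part of the commit itself):

    import mygrad as mg
    import numpy as np

    x = mg.Tensor([1.0, 2.0, 3.0])
    (x ** 2).backward()
    assert x.grad is not None

    # Taking a view of `x` no longer nulls its gradient ...
    view = x[:2]
    assert x.grad is not None

    # ... and the view's grad is derived from, and shares memory with, the base's grad
    assert np.shares_memory(view.grad, x.grad)

    # Involving `x` in a non-view operation still clears both gradients
    _ = +x
    assert x.grad is None and view.grad is None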
11 changes: 8 additions & 3 deletions src/mygrad/_utils/__init__.py
@@ -20,7 +20,7 @@
from mygrad.operation_base import Operation

__all__ = [
"collect_all_operations",
"collect_all_operations_and_clear_grads",
"ContextTracker",
"reduce_broadcast",
"SkipGradient",
@@ -32,12 +32,17 @@
T = TypeVar("T")


def collect_all_operations(t: "Tensor", seen: Set["WeakRef[Operation]"]):
def collect_all_operations_and_clear_grads(
t: "Tensor", seen: Set["WeakRef[Operation]"]
):
"""Recursively accumulates in `seen` all operations involved
in creating `t`.
`seen` is updated in-place
"""
t._view_grad = None
t._grad = None

if t.creator is None or t.constant:
return

@@ -49,7 +54,7 @@ def collect_all_operations(t: "Tensor", seen: Set["WeakRef[Operation]"]):
seen.add(c)

for t in t.creator.variables:
collect_all_operations(t, seen)
collect_all_operations_and_clear_grads(t, seen)


class WeakRef(Generic[T]):
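
The renamed helper now clears gradients as it walks the graph. Roughly, its contract looks like the following sketch, which pokes at MyGrad's private `_utils` module and internal attributes, so it is illustrative rather than supported usage:

    from typing import Set

    import mygrad as mg
    from mygrad._utils import WeakRef, collect_all_operations_and_clear_grads
    from mygrad.operation_base import Operation

    x = mg.Tensor([1.0, 2.0, 3.0])
    (x * 2).backward()
    assert x.grad is not None

    view = x[:2]  # a view op: does not clear `x.grad`

    seen: Set[WeakRef[Operation]] = set()
    collect_all_operations_and_clear_grads(view, seen=seen)

    # `seen` now holds weak references to the ops upstream of `view`, and every
    # tensor visited along the way had its grad / view-grad cleared in-place
    assert x.grad is None and view.grad is None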
92 changes: 52 additions & 40 deletions src/mygrad/tensor_base.py
@@ -25,7 +25,11 @@
import mygrad._utils.graph_tracking as _track
import mygrad._utils.lock_management as _mem
from mygrad._tensor_core_ops.indexing import GetItem, SetItem
from mygrad._utils import WeakRef, WeakRefIterable, collect_all_operations
from mygrad._utils import (
WeakRef,
WeakRefIterable,
collect_all_operations_and_clear_grads,
)
from mygrad.errors import DisconnectedView
from mygrad.linalg.ops import MatMul
from mygrad.math.arithmetic.ops import (
@@ -660,15 +664,14 @@ def grad(self) -> Optional[np.ndarray]:
if self._base is None:
return self._grad

if self._view_grad is not None:
if self._view_grad is not None and self._view_grad.base is self._base._grad:
# view grad has been computed already
return self._view_grad

if self._base._grad is None or self._creator is None:
# ``self`` had its graph, connecting it to its base, cleared.
# ``self._view_grad`` can't be computed without this info.
# Defer to ``self.grad`` so that the present tensor
return self._grad
return None

(view_parent,) = self._creator.variables

@@ -775,37 +778,24 @@ def _op(

_uniques_bases_then_arrs = ()

# cast all input-vars to tensors
if _track.TRACK_GRAPH:
# lock memory of array data and clear any tensor
# gradients
tensor_vars = tuple(
cls(var, constant=True, copy=False)
if not isinstance(var, Tensor)
else var.null_grad(_clear_view_info=True)
for var in input_vars
)
if _mem.MEM_GUARD:

_uniques_bases_then_arrs = WeakRefIterable(
_mem.lock_arr_writeability(x)
for x in _mem.unique_arrs_and_bases(tensor_vars)
)
tensor_vars = tuple(
cls(var, constant=True, copy=False) if not isinstance(var, Tensor) else var
for var in input_vars
)

else:
# operations are not being tracked - don't lock memory or null grads
tensor_vars = tuple(
cls(var, constant=True, copy=False)
if not isinstance(var, Tensor)
else var
for var in input_vars
# cast all input-vars to tensors
if _track.TRACK_GRAPH and _mem.MEM_GUARD:
# lock memory of array data
_uniques_bases_then_arrs = WeakRefIterable(
_mem.lock_arr_writeability(x)
for x in _mem.unique_arrs_and_bases(tensor_vars)
)

if op_args is None:
op_args = tuple()

if op_kwargs is None:
op_kwargs = dict()
op_kwargs = {}

f = Op()

@@ -830,13 +820,15 @@ def _op(
_base=None,
)

# Determine whether or not op was a view; if so, `base`
# points to parent Tensor
# points to parent tensor that op-output is a view of
base = None # type: Optional[Tensor]

# If output of op is a view - tracks the tensor var that is
# the parent of the view
parent_var: Optional[Tensor] = None

# Determine whether or not op was a view; if so, `base`
# points to parent Tensor
op_out_base = op_out.base
if f.can_return_view and op_out_base is not None:
vars_can_share_mem = (
@@ -853,11 +845,25 @@
or (op_out_base is parent_data_base)
or (op_out is parent_data)
):
if parent_var._base is not None and parent_var._creator is None:
parent_var._base = None

base = parent_var if parent_var.base is None else parent_var.base
break
else:
parent_var = None

for v in input_vars:
if isinstance(v, Tensor):
# tensor's graph has been cleared, but its base lingers
if v._base is not None and v._creator is None:
v._base = None

if base is None:
# non-view ops clear grads
v._grad = None
v._view_grad = None

if base is not None:
# we need to be able to replay view-ops for doing in-place operations
# on graphs with views
@@ -985,39 +991,45 @@ def backward(self, grad: Optional[ArrayLike] = None):
self.clear_graph()
return

# don't set self._grad yet because there is a grad-clearing step that
# occurs during graph creation
if grad is not None:
# `self` is guaranteed to be a tensor of floats
# so we can simply cast `grad` to be the same dtype
self._grad = asarray(grad, dtype=self.dtype)
_grad = asarray(grad, dtype=self.dtype)

if self._grad.shape != self.shape:
if _grad.shape != self.shape:
try:
# See if grad can broadcast to `self`
# raises ValueError if not
self._grad = np.multiply(
_grad = np.multiply(
np.full_like(self.data, fill_value=1.0),
self._grad,
_grad,
dtype=self.dtype,
)
if self._grad.shape != self.shape:
if _grad.shape != self.shape:
# mutual broadcasting occurred
raise ValueError()
except ValueError:
raise ValueError(
f"`tensor.backward(grad)` was passed a gradient with an incompatible shape.\n"
f"`grad` must be broadcast-compatible with `tensor.shape={self.shape}`\n"
f"Got `grad.shape={self._grad.shape}`"
f"Got `grad.shape={_grad.shape}`"
)
else:
self._grad = np.full_like(self.data, fill_value=1.0)
_grad = np.full_like(self.data, fill_value=1.0)

if self.creator is not None:
graph = set() # type: Set[WeakRef[Operation]]

# stores a set of all the operation-instances that participate in
# the computational graph up to and including the present operation
collect_all_operations(self, seen=graph)
graph = set() # type: Set[WeakRef[Operation]]

# populates graph and clears all grads
collect_all_operations_and_clear_grads(self, seen=graph)
self._grad = _grad
self._backward(graph=graph)
else:
self._grad = _grad

self.clear_graph()

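
The added `self._view_grad.base is self._base._grad` check in the `grad` property prevents a stale, previously-cached view-grad from being served after the base's grad has been nulled. This mirrors the updated `test_nulling_base_grad_reflects_in_view` test (a sketch; values are illustrative):

    import mygrad as mg

    base = mg.Tensor([1.0, 2.0, 3.0])
    view = base[:2]
    (base ** 2).backward()

    _ = view.grad  # computes and caches the view's grad from `base.grad`

    _ = +base      # a non-view op: nulls `base.grad`

    # The cached view-grad no longer derives from `base.grad`, so it is not reused
    assert base.grad is None
    assert view.grad is None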
15 changes: 2 additions & 13 deletions tests/tensor_base/test_no_null_grad_semantics.py
@@ -10,21 +10,10 @@
from tests.custom_strategies import tensors


def view_op(x):
return x[...]


def std_op(x):
return +x


@pytest.mark.parametrize("func", [view_op, std_op])
@given(x=tensors(include_grad=True))
def test_involving_a_tensor_in_a_graph_nulls_its_gradient(
func: Callable[[Tensor], Tensor], x: Tensor
):
def test_involving_a_tensor_in_a_graph_nulls_its_gradient(x: Tensor):
assert x.grad is not None
func(x)
_ = +x
assert x.grad is None
assert x._ops is not None

111 changes: 85 additions & 26 deletions tests/test_view_semantics.py
@@ -10,22 +10,45 @@
from mygrad.errors import InvalidBackprop
from tests.custom_strategies import tensors
from tests.utils.stateful import clear_all_mem_locking_state
from tests.utils.wrappers import clears_mem_state


def test_simple_view_grad_reflects_base_grad():
@pytest.mark.parametrize("view_pre_or_post_backward", ("pre", "post"))
def test_simple_view_grad_reflects_base_grad(view_pre_or_post_backward: str):
base = mg.Tensor([1.0, 2.0, 3.0])
view = base[:2]
assert view.base is base

if view_pre_or_post_backward == "pre":
view = base[:2]
assert view.base is base

(base ** 2).backward()

if view_pre_or_post_backward == "post":
view = base[:2]
assert view.base is base

assert_array_equal(view.grad, base.grad[:2])
assert np.shares_memory(view.grad, base.grad)
assert view.grad.base is base.grad

base.null_grad()
assert base.grad is None
assert view.grad is None


def test_simple_view_grad_reflects_nulled_base_grad():
@pytest.mark.parametrize("view_pre_or_post_backward", ("pre", "post"))
def test_simple_view_grad_reflects_nulled_base_grad(view_pre_or_post_backward: str):
base = mg.Tensor([1.0, 2.0, 3.0])
view = base[:2]

if view_pre_or_post_backward == "pre":
view = base[:2]

(base ** 2).backward()

if view_pre_or_post_backward == "post":
view = base[:2]

assert view.grad is not None

# Involving base in new graph should null its gradient
# and this should be reflected in its views
_ = +base
@@ -48,35 +71,26 @@ def test_simple_view_becomes_disconnected_from_base_via_clear_graph():
assert view.grad is None


@pytest.mark.xfail(
reason=""
"This is a known/documented inconsistency in MyGrad's "
"view semantics. It would be expensive to propagate "
"information forward this aggressively, and it is almost "
"certainly the case that the 'fix' would lead to a "
"less-intuitive user experience."
)
@clears_mem_state
def test_known_disagreement_between_view_grad_and_base():
@pytest.mark.parametrize("view_pre_or_post_backward", ("pre", "post"))
def test_nulling_base_grad_reflects_in_view(view_pre_or_post_backward):
base = mg.Tensor([1.0, 2.0, 3.0])
view = base[:2]

if view_pre_or_post_backward == "pre":
view = base[...][:2]

(base ** 2).backward()

if view_pre_or_post_backward == "post":
view = base[...][:2]

# pulling on `view.grad` will set its gradient
_ = view.grad

# Involving base in new graph nulls its gradient
# and disconnects it from any of its views
+base

assert base.grad is None

# But this doesn't propagate to `view` because it
# would be expensive to do so
#
# Despite view's base being set, its grad doesn't
# reflect the (nulled) grad of its base
assert view.base is base
assert view.grad is None # This should fail!
assert view.grad is None


def test_simple_view_becomes_disconnected_from_base_via_clear_graph2():
@@ -410,3 +424,48 @@ def test_resuming_graph_after_backprop_through_view(
assert_allclose(view, 3 * np.arange(4.0)[-2:])
assert_allclose(base.grad, np.ones_like(base))
assert_allclose(view.grad, np.ones_like(view))


@given(num_additional_views=st.integers(0, 3))
def test_sequence_of_interactions_with_view_and_backprop(num_additional_views: int):
base = mg.arange(4.0)[...]
base.backward([-1.0, 2.0, 3.0, -4.0])

view = base[-2:]
for _ in range(num_additional_views):
view = view[...]

# view's grad should be accurate even if grad was
# formed post-backprop
assert_allclose(view.grad, base.grad[-2:])
assert np.shares_memory(view.grad, base.grad)

# backpropping through base should update the
# view's grad
(2 * base).backward(-1)
assert_allclose(base.grad, np.full_like(base, -2))
assert_allclose(view.grad, base.grad[-2:])
assert np.shares_memory(view.grad, base.grad)

# taking a view of the base should not null its grad
view = base[-2:]
assert_allclose(base.grad, np.full_like(base, -2))

# but backpropping from the view should clear the base's
# grad and reset it to reflect the newest derivative
view.backward([-1.0, 10.0])
assert_allclose(view.grad, np.array([-1.0, 10.0]))
assert_allclose(base.grad, np.array([0.0, 0, -1.0, 10.0]))
assert np.shares_memory(view.grad, base.grad)

# involving in a new op should null both of their gradients
_ = +base

assert base.grad is None
assert view.grad is None

# view should be disconnected from base
(2 * view).backward()
assert view.base is None
assert_allclose(view.grad, np.full_like(view, fill_value=2.0))
assert base.grad is None
