Add torch.sparse.as_sparse_gradcheck decorator of gradcheck that allows gradcheck input function to receive and return sparse tensors #107150

Closed · wants to merge 25 commits · changes shown from 19 commits

Commits (25)
7f1868f  Add torch.sparse.enable_sparse_support decorator of gradcheck that al… (pearu, Aug 14, 2023)
0455be2  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 14, 2023)
04cb0fe  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 16, 2023)
824f62b  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 16, 2023)
cb2a49a  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 17, 2023)
139c7e9  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 18, 2023)
05c692c  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 18, 2023)
3759b45  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 18, 2023)
0639a30  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 18, 2023)
8266e92  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 18, 2023)
0da1529  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 18, 2023)
1eaf97a  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 18, 2023)
cbd2bb9  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 18, 2023)
e717e39  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 21, 2023)
9a180c8  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 22, 2023)
731a4f9  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 22, 2023)
5b3a18f  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 22, 2023)
60c95c6  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 22, 2023)
a12a18c  Update on "Add torch.sparse.enable_sparse_support decorator of gradch… (pearu, Aug 22, 2023)
88033a8  Update on "Add torch.sparse.as_sparse_gradcheck decorator of gradchec… (pearu, Aug 23, 2023)
abda295  Update on "Add torch.sparse.as_sparse_gradcheck decorator of gradchec… (pearu, Aug 23, 2023)
31e4b3d  Update on "Add torch.sparse.as_sparse_gradcheck decorator of gradchec… (pearu, Aug 24, 2023)
01de747  Update on "Add torch.sparse.as_sparse_gradcheck decorator of gradchec… (pearu, Aug 24, 2023)
88c6649  Update on "Add torch.sparse.as_sparse_gradcheck decorator of gradchec… (pearu, Aug 25, 2023)
418cbac  Update on "Add torch.sparse.as_sparse_gradcheck decorator of gradchec… (pearu, Aug 25, 2023)
9 changes: 9 additions & 0 deletions docs/source/sparse.rst
@@ -1332,6 +1332,15 @@ To manage checking sparse tensor invariants, see:

sparse.check_sparse_tensor_invariants

To use sparse tensors with the :func:`~torch.autograd.gradcheck` function,
see:

.. autosummary::
:toctree: generated
:nosignatures:

sparse.as_sparse_gradcheck

Unary functions
---------------

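The docs hunk above only adds an autosummary entry; for quick context, here is a minimal usage sketch of the decorator being documented (assuming this PR is applied; the example tensor is illustrative):

    import torch

    # Wrap the stock gradcheck so it can consume functions whose inputs
    # and/or outputs are sparse tensors.
    gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)

    # A small COO input with requires_grad set; gradcheck will perturb its values.
    x = torch.tensor([[0., 1.], [2., 3.]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)

    # Check a function that takes a sparse tensor and returns a strided one.
    assert gradcheck(lambda t: t.to_dense(), x)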
50 changes: 50 additions & 0 deletions test/test_sparse.py
@@ -5032,6 +5032,56 @@ def test_like_fns(self, layout, device, dtype, op):
torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
result.shape, result.layout)

# TODO: @all_sparse_layouts('layout', include_strided=False)
@parametrize("layout", [subtest(torch.sparse_coo, name='SparseCOO'),
subtest(torch.sparse_csr, name='SparseCSR')])
@parametrize("masked", [subtest(False, name='nonmasked'), subtest(True, name='masked')])
@parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')])
def test_as_sparse_gradcheck(self, layout, masked, fast_mode):
gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)

def identity(x):
return x

if layout is torch.sparse_coo:
def values_mth(x):
# TODO: remove coalesced after gh-107097 is fixed.
return x.coalesce().values()
else:
values_mth = torch.Tensor.values

for func in (torch.Tensor.to_dense,
torch.Tensor.sum,
identity,
torch.Tensor.to_sparse,
values_mth,
):
if layout is torch.sparse_csr and func.__name__ == 'values':
# FIXME: RuntimeError: indices expected sparse
# coordinate tensor layout but got SparseCsr. Likely
# works when gh-107126 is fixed.
continue
for x in self.generate_simple_inputs(
layout,
dtype=torch.float64,
# TODO: fix gh-104868 to enable batched samples:
enable_batch=layout is not torch.sparse_csr,
enable_hybrid=not (
layout is torch.sparse_csr and (
# FIXME: RuntimeError: sparse_mask(): the
# number of sparse dimensions in `self`
# should match that of the `mask`. Got
# `self.sparse_dim() == 3` !=
# `mask.sparse_dim() == 2`
func.__name__ == 'sum'
# FIXME: RuntimeError: expected
# col_indices to be a contiguous tensor
# per batch
or func.__name__ == 'to_sparse'
))):
gradcheck(func, x.requires_grad_(True), masked=masked, fast_mode=fast_mode)


# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')

176 changes: 176 additions & 0 deletions torch/sparse/__init__.py
@@ -28,6 +28,7 @@
'log_softmax',
'SparseSemiStructuredTensor',
'to_sparse_semi_structured',
'as_sparse_gradcheck',
]

addmm = _add_docstr(_sparse._sparse_addmm, r"""
@@ -495,3 +496,178 @@ def test_mth(*args, **kwargs):
return mth(*args, **kwargs)

return test_mth


def as_sparse_gradcheck(gradcheck):
"""Decorator for torch.autograd.gradcheck or its functools.partial
variants that extends the gradcheck function with support to input
functions that operate on or/and return sparse tensors.

The specified gradcheck function itself is guaranteed to operate
on strided tensors only.

For example:

>>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
>>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
>>> gradcheck(lambda x: x.to_sparse_csr(), x)
True
"""

def gradcheck_with_sparse_support(func, inputs, **kwargs):
"""Same as :func:`torch.autograd.gradcheck` but with sparse tensors
inputs and outputs support.
"""
masked = masked_grad = kwargs.pop('masked', False)
sparse_layouts = {torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
STRIDED_REPRESENTATION = '__STRIDED_REPRESENTATION__'
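# Each differentiable sparse input is expanded into a flat triple
# (STRIDED_REPRESENTATION, <metadata dict>, <strided values tensor>) so that
# gradcheck only ever sees strided differentiable tensors; the metadata dict
# carries everything needed to rebuild the sparse tensor inside func_wrapper.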

def _convert_to_strided_representation(obj):
"""Convert a differentiable non-strided tensor to a representation
containing differentiable strided tensors only.
"""
if isinstance(obj, torch.Tensor) and obj.requires_grad:
d = dict(layout=obj.layout, shape=obj.shape, original=obj)
device = obj.device
if obj.layout is torch.sparse_coo:
obj = obj.coalesce()
indices, values = obj.indices(), obj.values()
d.update(is_coalesced=obj.is_coalesced())
if masked:
d.update(indices=indices)
return (STRIDED_REPRESENTATION, d, values.requires_grad_(True))
else:
# Materialize unspecified elements with zero values
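# (torch.ones(...).to_sparse(...) yields a COO tensor whose indices cover
# every element; subtracting 1 in-place zeroes its values, and adding `obj`
# writes the specified values back, so gradcheck can also perturb the
# unspecified elements under non-masked semantics.)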
full_obj = torch.ones(obj.shape, dtype=obj.dtype, device=obj.device).to_sparse(
layout=torch.sparse_coo, dense_dim=obj.dense_dim())
full_obj.values().sub_(1)
full_obj += obj
d.update(indices=full_obj.indices())
return (STRIDED_REPRESENTATION, d, full_obj.values().requires_grad_(True))
elif obj.layout is torch.sparse_csr:
compressed_indices = obj.crow_indices()
plain_indices = obj.col_indices()
values = obj.values()
indices_dtype = compressed_indices.dtype
batch_dim = compressed_indices.ndim - 1
if masked:
indices = torch._convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
d.update(
indices=indices, # TODO: eliminate after gh-107373
compressed_indices=compressed_indices,
plain_indices=plain_indices)
return (STRIDED_REPRESENTATION, d, values.requires_grad_(True))
else:
batch_shape = obj.shape[:batch_dim]
dense_shape = values.shape[batch_dim + 1:]
full_nnz = obj.shape[batch_dim:batch_dim + 2].numel()

tmp = torch.ones(obj.shape[:batch_dim + 2], dtype=torch.int8, device=device).to_sparse(layout=obj.layout)
full_compressed_indices = tmp.crow_indices().to(dtype=indices_dtype)
full_plain_indices = tmp.col_indices().to(dtype=indices_dtype)
full_compressed_indices.expand(*batch_shape, *full_compressed_indices.shape)
full_plain_indices.expand(*batch_shape, *full_plain_indices.shape)

full_values = torch.zeros((*batch_shape, full_nnz, *dense_shape), dtype=values.dtype, device=values.device)

if values.numel() > 0:
strides = torch.empty(obj.shape[batch_dim:batch_dim + 2]).stride()
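# `strides` are the row-major strides of a single (compressed-dim, plain-dim)
# block; they are used below to flatten the per-batch 2-D COO indices into
# linear positions within the fully materialized nnz dimension.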
if batch_dim > 0:
batch_compressed_indices = compressed_indices.view(-1, *compressed_indices.shape[batch_dim:])
batch_plain_indices = plain_indices.view(-1, *plain_indices.shape[batch_dim:])
batch_values = values.view(-1, *values.shape[batch_dim:])
batch_full_values = full_values.view(-1, *full_values.shape[batch_dim:])
for i in range(batch_shape.numel()):
# TODO: eliminate this for-loop after gh-104868 is fixed
indices = torch._convert_indices_from_csr_to_coo(
batch_compressed_indices[i], batch_plain_indices[i])
flatten_indices = (torch.tensor([strides], device=device, dtype=indices.dtype).T
* indices).sum(0)
batch_full_values[i][flatten_indices] = batch_values[i]
else:
indices = torch._convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
flatten_indices = (torch.tensor([strides], device=device, dtype=indices.dtype).T * indices).sum(0)
full_values[flatten_indices] = values

full_indices = torch.ones(obj.shape[:batch_dim + 2],
device=device, dtype=torch.int8).nonzero().to(dtype=torch.int64).T
d.update(
indices=full_indices, # TODO: eliminate full_indices after gh-107373 is fixed
compressed_indices=full_compressed_indices,
plain_indices=full_plain_indices)
return (STRIDED_REPRESENTATION, d, full_values.requires_grad_(True))
elif obj.layout in {torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}:
raise NotImplementedError(f'conversion of {obj.layout} tensor to strided representation')
else:
return obj
return obj

def _restore_from_strided_representation(d, values):
"""Restore a non-strided differentiable tensor from its strided
representation.
"""
if d['layout'] is torch.sparse_coo:
return torch.sparse_coo_tensor(d['indices'], values, size=d['shape'], is_coalesced=d['is_coalesced'])
elif d['layout'] in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
dense_dim = d['original'].dense_dim()
batch_dim = d['compressed_indices'].ndim - 1
if batch_dim == 0 and dense_dim > 0:
# TODO: remove this if-block after gh-107373 is fixed
r = torch.sparse_coo_tensor(
d['indices'], values, size=d['shape'], is_coalesced=True).to_sparse(layout=d['layout'])
# TODO: use to_sparse(..., dense_dim=dense_dim)
# and remove the assert below after gh-107451 is
# fixed.
assert r.dense_dim() == dense_dim, (r.dense_dim(), dense_dim)
return r
return torch.sparse_compressed_tensor(d['compressed_indices'], d['plain_indices'], values,
size=d['shape'], layout=d['layout'])
else:
raise ValueError(f'unsupported sparse layout: {d["layout"]}')

def convert_to_strided_representation(args):
if not isinstance(args, (list, tuple)):
args = args,
new_args = []
for a in args:
if isinstance(a, torch.Tensor) and a.requires_grad:
a_ = _convert_to_strided_representation(a)
if a_ is not a:
# the strided representation needs to be inserted into
# the arguments list element-wise because gradcheck
# does not detect differentiable inputs within deep
# Python structures.
new_args.extend(a_)
continue
new_args.append(a)
return tuple(new_args)

def restore_from_strided_representation(args):
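# Scan the flat argument stream produced by convert_to_strided_representation:
# whenever the sentinel is found, the following metadata dict and values
# tensor are folded back into a single sparse tensor.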
new_args = []
args = list(args)
while args:
a = args.pop(0)
if a == STRIDED_REPRESENTATION:
a = _restore_from_strided_representation(d=args.pop(0), values=args.pop(0))
new_args.append(a)
return tuple(new_args)

def func_wrapper(*args, **kwargs):
restored_args = restore_from_strided_representation(args)

# convert differentiable output sparse tensors to strided
# tensors:
outputs = func(*restored_args, **kwargs)

strided_outputs = tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
strided_outputs = tuple((o.to_dense(masked_grad=masked_grad)
if isinstance(o, torch.Tensor) and o.requires_grad and o.layout in sparse_layouts else o)
for o in strided_outputs)

return strided_outputs if isinstance(outputs, (list, tuple)) else strided_outputs[0]

args = (func_wrapper, convert_to_strided_representation(inputs))

return gradcheck(*args, **kwargs)

return gradcheck_with_sparse_support
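Finally, a hedged sketch of the functools.partial usage mentioned in the docstring above (assuming this PR is applied; the CSR input and keyword arguments are chosen for illustration only):

    import functools
    import torch

    # Pre-configure gradcheck, then add sparse support on top; kwargs other
    # than `masked` are forwarded to the wrapped gradcheck unchanged.
    gradcheck = torch.sparse.as_sparse_gradcheck(
        functools.partial(torch.autograd.gradcheck, fast_mode=True))

    # With masked=True only the specified elements of the CSR input are perturbed.
    x = torch.eye(3, dtype=torch.float64).to_sparse_csr().requires_grad_(True)
    assert gradcheck(lambda t: t.sum(), x, masked=True)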