Skip to content

Commit

Permalink
Fix gradient computation in to_symmetric (#327)
Browse files Browse the repository at this point in the history
update
  • Loading branch information
rusty1s authored May 22, 2023
1 parent 20c3dd9 commit 578e6e0
Showing 1 changed file with 123 additions and 45 deletions.
168 changes: 123 additions & 45 deletions torch_sparse/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,15 @@ def from_edge_index(
is_sorted: bool = False,
trust_data: bool = False,
):
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
value=edge_attr, sparse_sizes=sparse_sizes,
is_sorted=is_sorted, trust_data=trust_data)
return SparseTensor(
row=edge_index[0],
rowptr=None,
col=edge_index[1],
value=edge_attr,
sparse_sizes=sparse_sizes,
is_sorted=is_sorted,
trust_data=trust_data,
)

@classmethod
def from_dense(self, mat: torch.Tensor, has_value: bool = True):
Expand All @@ -84,13 +90,22 @@ def from_dense(self, mat: torch.Tensor, has_value: bool = True):
if has_value:
value = mat[row, col]

return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
return SparseTensor(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True,
trust_data=True,
)

@classmethod
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
has_value: bool = True):
def from_torch_sparse_coo_tensor(
self,
mat: torch.Tensor,
has_value: bool = True,
):
mat = mat.coalesce()
index = mat._indices()
row, col = index[0], index[1]
Expand All @@ -99,27 +114,46 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
if has_value:
value = mat.values()

return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
return SparseTensor(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True,
trust_data=True,
)

@classmethod
def from_torch_sparse_csr_tensor(self, mat: torch.Tensor,
has_value: bool = True):
def from_torch_sparse_csr_tensor(
self,
mat: torch.Tensor,
has_value: bool = True,
):
rowptr = mat.crow_indices()
col = mat.col_indices()

value: Optional[torch.Tensor] = None
if has_value:
value = mat.values()

return SparseTensor(row=None, rowptr=rowptr, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
return SparseTensor(
row=None,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True,
trust_data=True,
)

@classmethod
def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
dtype: Optional[int] = None, device: Optional[torch.device] = None,
def eye(self,
M: int,
N: Optional[int] = None,
has_value: bool = True,
dtype: Optional[int] = None,
device: Optional[torch.device] = None,
fill_cache: bool = False):

N = M if N is None else N
Expand Down Expand Up @@ -214,13 +248,19 @@ def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
def has_value(self) -> bool:
return self.storage.has_value()

def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
def set_value_(
self,
value: Optional[torch.Tensor],
layout: Optional[str] = None,
):
self.storage.set_value_(value, layout)
return self

def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
def set_value(
self,
value: Optional[torch.Tensor],
layout: Optional[str] = None,
):
return self.from_storage(self.storage.set_value(value, layout))

def sparse_sizes(self) -> Tuple[int, int]:
Expand Down Expand Up @@ -275,13 +315,21 @@ def __eq__(self, other) -> bool:
# Utility functions #######################################################

def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
device=self.device())
value = torch.full(
(self.nnz(), ),
fill_value,
dtype=dtype,
device=self.device(),
)
return self.set_value_(value, layout='coo')

def fill_value(self, fill_value: float, dtype: Optional[int] = None):
value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
device=self.device())
value = torch.full(
(self.nnz(), ),
fill_value,
dtype=dtype,
device=self.device(),
)
return self.set_value(value, layout='coo')

def sizes(self) -> List[int]:
Expand Down Expand Up @@ -373,8 +421,8 @@ def to_symmetric(self, reduce: str = "sum"):
value = torch.cat([value, value])[perm]
value = segment_csr(value, ptr, reduce=reduce)

new_row = torch.cat([row, col], dim=0, out=perm)[idx]
new_col = torch.cat([col, row], dim=0, out=perm)[idx]
new_row = torch.cat([row, col], dim=0)[idx]
new_col = torch.cat([col, row], dim=0)[idx]

out = SparseTensor(
row=new_row,
Expand Down Expand Up @@ -406,8 +454,11 @@ def requires_grad(self) -> bool:
else:
return False

def requires_grad_(self, requires_grad: bool = True,
dtype: Optional[int] = None):
def requires_grad_(
self,
requires_grad: bool = True,
dtype: Optional[int] = None,
):
if requires_grad and not self.has_value():
self.fill_value_(1., dtype)

Expand Down Expand Up @@ -478,21 +529,29 @@ def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
row, col, value = self.coo()

if value is not None:
mat = torch.zeros(self.sizes(), dtype=value.dtype,
device=self.device())
mat = torch.zeros(
self.sizes(),
dtype=value.dtype,
device=self.device(),
)
else:
mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

if value is not None:
mat[row, col] = value
else:
mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
device=mat.device)
mat[row, col] = torch.ones(
self.nnz(),
dtype=mat.dtype,
device=mat.device,
)

return mat

def to_torch_sparse_coo_tensor(
self, dtype: Optional[int] = None) -> torch.Tensor:
self,
dtype: Optional[int] = None,
) -> torch.Tensor:
row, col, value = self.coo()
index = torch.stack([row, col], dim=0)

Expand All @@ -502,7 +561,9 @@ def to_torch_sparse_coo_tensor(
return torch.sparse_coo_tensor(index, value, self.sizes())

def to_torch_sparse_csr_tensor(
self, dtype: Optional[int] = None) -> torch.Tensor:
self,
dtype: Optional[int] = None,
) -> torch.Tensor:
rowptr, col, value = self.csr()

if value is None:
Expand All @@ -511,7 +572,9 @@ def to_torch_sparse_csr_tensor(
return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())

def to_torch_sparse_csc_tensor(
self, dtype: Optional[int] = None) -> torch.Tensor:
self,
dtype: Optional[int] = None,
) -> torch.Tensor:
colptr, row, value = self.csc()

if value is None:
Expand Down Expand Up @@ -548,8 +611,11 @@ def cpu(self) -> SparseTensor:
return self.device_as(torch.tensor(0., device='cpu'))


def cuda(self, device: Optional[Union[int, str]] = None,
non_blocking: bool = False):
def cuda(
self,
device: Optional[Union[int, str]] = None,
non_blocking: bool = False,
):
return self.device_as(torch.tensor(0., device=device or 'cuda'))


Expand Down Expand Up @@ -654,17 +720,29 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
value = torch.from_numpy(mat.data)
sparse_sizes = mat.shape[:2]

storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
storage = SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=colptr,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True,
)

return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
def to_scipy(
self: SparseTensor,
layout: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
) -> ScipySparseMatrix:
assert self.dim() == 2
layout = get_layout(layout)

Expand Down

0 comments on commit 578e6e0

Please sign in to comment.