Skip to content

Commit

Permalink
Remove torch scatter dep (#88)
Browse files Browse the repository at this point in the history
* fea: remove torch scatter dep

* remove torch-scatter CI install, rm .devcontainer/devcontainer.json

* bump version when dropping scatter

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
CompRhys and janosh authored Sep 10, 2024
1 parent 3c6411a commit 5413c71
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 40 deletions.
13 changes: 0 additions & 13 deletions .devcontainer/devcontainer.json

This file was deleted.

1 change: 0 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ jobs:
- name: Install dependencies
run: |
pip install torch==2.2.1 --index-url https://download.pytorch.org/whl/cpu
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cpu.html
uv pip install .[test] --system
- name: Run Tests
Expand Down
10 changes: 1 addition & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,7 @@ The aim of `aviary` is to contain multiple models for materials discovery under

## Installation

Aviary requires [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter). `pip install` it with

```sh
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cpu.html
```

Make sure you replace `2.2.1` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable.

Then install `aviary` from source with
Users can install `aviary` from source with

```sh
pip install -U git+https://github.com/CompRhys/aviary
Expand Down
6 changes: 3 additions & 3 deletions aviary/cgcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_add, scatter_mean

from aviary.core import BaseModelClass
from aviary.networks import SimpleNetwork
from aviary.scatter import scatter_reduce

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -114,7 +114,7 @@ def forward(
"""
atom_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)

crys_fea = scatter_mean(atom_fea, crystal_atom_idx, dim=0)
crys_fea = scatter_reduce(atom_fea, crystal_atom_idx, dim=0, reduce="mean")

# NOTE required to match the reference implementation
crys_fea = nn.functional.softplus(crys_fea)
Expand Down Expand Up @@ -236,7 +236,7 @@ def forward(

# take the elementwise product of the filter and core
nbr_msg = filter_fea * core_fea
nbr_summed = scatter_add(nbr_msg, self_idx, dim=0)
nbr_summed = scatter_reduce(nbr_msg, self_idx, dim=0, reduce="sum")

nbr_summed = self.bn2(nbr_summed)
return self.softplus2(atom_in_fea + nbr_summed)
76 changes: 76 additions & 0 deletions aviary/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch


def scatter_reduce(src, index, dim=-1, dim_size=None, reduce="sum"):
"""Performs a scatter-reduce operation on the input tensor.
This function scatters the elements from the source tensor (src) into a new tensor
of shape determined by dim_size along the specified dimension (dim), using the
given reduction method. It's compatible with autograd for gradient computation.
NOTE this function was written by Claude 3.5 Sonnet.
Args:
src (torch.Tensor): The source tensor.
index (torch.Tensor): The indices of elements to scatter. Must be 1D or have
the same number of dimensions as src.
dim (int, optional): The axis along which to index. Defaults to -1.
dim_size (int, optional): The size of the output tensor's dimension `dim`.
If None, it's inferred as index.max().item() + 1. Defaults to None.
reduce (str, optional): The reduction operation to perform.
Options: "sum", "mean", "amax", "max", "amin", "min", "prod".
Defaults to "sum".
Returns:
torch.Tensor: The output tensor after the scatter-reduce operation.
Raises:
ValueError: If an unsupported reduction method is specified.
RuntimeError: If index and src tensors are incompatible.
Example:
>>> src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
>>> index = torch.tensor([0, 1, 0, 1, 2])
>>> scatter_reduce(src, index, dim=0, reduce="sum")
tensor([4., 6., 5.])
"""
if dim_size is None:
dim_size = index.max().item() + 1

# Prepare the output tensor shape
shape = list(src.shape)
shape[dim] = dim_size

# Ensure index has the same number of dimensions as src
if index.dim() != src.dim():
if index.dim() != 1:
raise RuntimeError(
"Index tensor must be 1D or have the same number of dimensions "
f"as src tensor. {index.shape=} != {src.shape=}"
)
# Expand index to match src dimensions
repeat_shape = [1] * src.dim()
repeat_shape[dim] = src.size(dim)
index = index.view(-1, *[1] * (src.dim() - 1)).expand_as(src)

# Perform scatter_reduce operation
if reduce in ["sum", "mean"]:
out = torch.zeros(shape, dtype=src.dtype, device=src.device)
out = out.scatter_add(dim, index, src)
if reduce == "mean":
count = torch.zeros(shape, dtype=src.dtype, device=src.device)
count = count.scatter_add(dim, index, torch.ones_like(src))
out = out / (count + (count == 0).float()) # avoid division by zero
elif reduce in ["amax", "max"]:
out = torch.full(shape, float("-inf"), dtype=src.dtype, device=src.device)
out = torch.max(out, out.scatter(dim, index, src))
elif reduce in ["amin", "min"]:
out = torch.full(shape, float("inf"), dtype=src.dtype, device=src.device)
out = torch.min(out, out.scatter(dim, index, src))
elif reduce == "prod":
out = torch.ones(shape, dtype=src.dtype, device=src.device)
out = out.scatter(dim, index, src, reduce="multiply")
else:
raise ValueError(f"Unsupported reduction method: {reduce}")

return out
14 changes: 7 additions & 7 deletions aviary/segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import torch
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_add, scatter_max

from aviary.networks import SimpleNetwork
from aviary.scatter import scatter_reduce

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -38,12 +38,12 @@ def forward(self, x: Tensor, index: Tensor) -> Tensor:
"""
gate = self.gate_nn(x)

gate -= scatter_max(gate, index, dim=0)[0][index]
gate -= scatter_reduce(gate, index, dim=0, reduce="amax")[index]
gate = gate.exp()
gate /= scatter_add(gate, index, dim=0)[index] + 1e-10
gate /= scatter_reduce(gate, index, dim=0, reduce="sum")[index] + 1e-10

x = self.message_nn(x)
return scatter_add(gate * x, index, dim=0)
return scatter_reduce(gate * x, index, dim=0, reduce="sum")

def __repr__(self) -> str:
gate_nn, message_nn = self.gate_nn, self.message_nn
Expand Down Expand Up @@ -78,12 +78,12 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor:
"""
gate = self.gate_nn(x)

gate -= scatter_max(gate, index, dim=0)[0][index]
gate -= scatter_reduce(gate, index, dim=0, reduce="amax")[index]
gate = (weights**self.pow) * gate.exp()
gate /= scatter_add(gate, index, dim=0)[index] + 1e-10
gate /= scatter_reduce(gate, index, dim=0, reduce="sum")[index] + 1e-10

x = self.message_nn(x)
return scatter_add(gate * x, index, dim=0)
return scatter_reduce(gate * x, index, dim=0, reduce="sum")

def __repr__(self) -> str:
pow, gate_nn, message_nn = float(self.pow), self.gate_nn, self.message_nn
Expand Down
6 changes: 4 additions & 2 deletions aviary/wren/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_mean

from aviary.core import BaseModelClass
from aviary.networks import ResidualNetwork, SimpleNetwork
from aviary.scatter import scatter_reduce
from aviary.segments import MessageLayer, WeightedAttentionPooling

if TYPE_CHECKING:
Expand Down Expand Up @@ -261,7 +261,9 @@ def forward(
for attnhead in self.cry_pool
]

return scatter_mean(torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0)
return scatter_reduce(
torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0, reduce="mean"
)

def __repr__(self) -> str:
return (
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/Roost.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"\n",
"print(f\"{TORCH_VERSION=}\")\n",
"\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}.html\n",
"!pip install -U git+https://github.com/CompRhys/aviary.git # install aviary\n",
"!wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997"
]
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/Wren.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"\n",
"print(f\"{TORCH_VERSION=}\")\n",
"\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}.html\n",
"!pip install -U git+https://github.com/CompRhys/aviary.git # install aviary\n",
"!wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997"
]
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "aviary"
version = "1.0.0"
version = "1.1.0"
description = "A collection of machine learning models for materials discovery"
authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }]
readme = "README.md"
Expand All @@ -27,7 +27,6 @@ classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Chemistry",
Expand All @@ -42,7 +41,6 @@ dependencies = [
"scikit_learn",
"tensorboard",
"torch",
"torch_scatter",
"tqdm",
"wandb",
]
Expand Down Expand Up @@ -114,6 +112,7 @@ ignore = [
"D105", # Missing docstring in magic method
"D205", # 1 blank line required between summary line and description
"E731", # Do not assign a lambda expression, use a def
"ISC001",
"PD901", # pandas-df-variable-name
"PLR", # pylint refactor
"PT006", # pytest-parametrize-names-wrong-type
Expand Down

0 comments on commit 5413c71

Please sign in to comment.