Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move all torch.ops.load calls to the __init__.py scripts #89

Merged
merged 9 commits into from
Mar 17, 2023
2 changes: 0 additions & 2 deletions src/pytorch/BatchedNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@
# SOFTWARE.
#

import os
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
from typing import List, NamedTuple, Tuple, Union

torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))
batchedLinear = torch.ops.NNPOpsBatchedNN.BatchedLinear


Expand Down
6 changes: 1 addition & 5 deletions src/pytorch/CFConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,11 @@
# SOFTWARE.
#

import os.path
import torch
from torch import Tensor

from NNPOps.CFConvNeighbors import CFConvNeighbors

torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))
torch.classes.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))

class CFConv(torch.nn.Module):
"""
Optimized continious-filter convolution layer (CFConv)
Expand Down Expand Up @@ -84,4 +80,4 @@ def __init__(self, gaussianWidth: float, activation: str,

def forward(self, neighbors: CFConvNeighbors, positions: Tensor, input: Tensor) -> Tensor:

return CFConv.operation(self.holder, neighbors.holder, positions, input)
return CFConv.operation(self.holder, neighbors.holder, positions, input)
5 changes: 1 addition & 4 deletions src/pytorch/CFConvNeighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,9 @@
# SOFTWARE.
#

import os.path
import torch
from torch import Tensor

torch.classes.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))

class CFConvNeighbors(torch.nn.Module):
"""
Optimized nearest-neighbor implementation for the continious-filter convolution (CFConf)
Expand All @@ -45,4 +42,4 @@ def __init__(self, cutoff: float) -> None:
@torch.jit.export
def build(self, positions: Tensor) -> None:

self.holder.build(positions)
self.holder.build(positions)
6 changes: 1 addition & 5 deletions src/pytorch/SymmetryFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,10 @@
# SOFTWARE.
#

import os.path
from typing import List, Optional, Tuple
import torch
from torch import Tensor

torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))
torch.classes.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))

Holder = torch.classes.NNPOpsANISymmetryFunctions.Holder
operation = torch.ops.NNPOpsANISymmetryFunctions.operation

Expand Down Expand Up @@ -124,4 +120,4 @@ def forward(self, species_positions: Tuple[Tensor, Tensor],
radial, angular = operation(self.holder, positions[0], cell)
features = torch.cat((radial, angular), dim=1).unsqueeze(0)

return species, features
return species, features
7 changes: 6 additions & 1 deletion src/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
'''
High-performance PyTorch operations for neural network potentials
'''
import os.path
import site
import torch
torch.ops.load_library(os.path.join(site.getsitepackages()[-1],"NNPOps", "libNNPOpsPyTorch.so"))
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved
torch.classes.load_library(os.path.join(site.getsitepackages()[-1],"NNPOps", "libNNPOpsPyTorch.so"))

from NNPOps.OptimizedTorchANI import OptimizedTorchANI
from NNPOps.OptimizedTorchANI import OptimizedTorchANI
7 changes: 6 additions & 1 deletion src/pytorch/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
'''
Neighbor operations
'''
import site
import os
import torch

from NNPOps.neighbors.getNeighborPairs import getNeighborPairs
torch.ops.load_library(os.path.join(site.getsitepackages()[-1],"NNPOps", "libNNPOpsPyTorch.so"))
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved

from NNPOps.neighbors.getNeighborPairs import getNeighborPairs