diff --git a/README.md b/README.md
index f7fa6c36..c9cd5187 100644
--- a/README.md
+++ b/README.md
@@ -322,10 +322,10 @@ Let's define a `Trainer` instance, using for example of the already existing `GI
 ```python
 from deeprank2.trainer import Trainer
-from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
+from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
 
 trainer = Trainer(
-    NaiveNetwork,
+    VanillaNetwork,
     dataset_train,
     dataset_val,
     dataset_test
@@ -389,11 +389,11 @@ Finally, the `Trainer` instance can be defined and the new data can be tested:
 ```python
 from deeprank2.trainer import Trainer
-from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
+from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
 from deeprank2.utils.exporters import HDF5OutputExporter
 
 trainer = Trainer(
-    NaiveNetwork,
+    VanillaNetwork,
     dataset_test = dataset_test,
     pretrained_model = "",
     output_exporters = [HDF5OutputExporter("")]
diff --git a/deeprank2/neuralnets/cnn/model3d.py b/deeprank2/neuralnets/cnn/model3d.py
index db90924c..126327a1 100644
--- a/deeprank2/neuralnets/cnn/model3d.py
+++ b/deeprank2/neuralnets/cnn/model3d.py
@@ -23,7 +23,16 @@
 # ----------------------------------------------------------------------
 
 
-class CnnRegression(nn.Module):  # noqa: D101
+class CnnRegression(nn.Module):
+    """Convolutional Neural Network architecture for regression.
+
+    This type of network is used to predict a single scalar value of a continuous variable.
+
+    Args:
+        num_features: Number of features in the input data.
+        box_shape: Shape of the input data.
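+
+    Example:
+        A minimal instantiation sketch (the feature count and grid shape below are arbitrary):
+
+        >>> from deeprank2.neuralnets.cnn.model3d import CnnRegression
+        >>> model = CnnRegression(num_features=8, box_shape=(10, 10, 10))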
+    """
+
     def __init__(self, num_features: int, box_shape: tuple[int]):
         super().__init__()
@@ -76,7 +85,16 @@ def forward(self, data):
 # ----------------------------------------------------------------------
 
 
-class CnnClassification(nn.Module):  # noqa: D101
+class CnnClassification(nn.Module):
+    """Convolutional Neural Network architecture for binary classification.
+
+    This type of network is used to predict the class of an input data point.
+
+    Args:
+        num_features: Number of features in the input data.
+        box_shape: Shape of the input data.
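+
+    Example:
+        A minimal instantiation sketch (argument values are arbitrary):
+
+        >>> from deeprank2.neuralnets.cnn.model3d import CnnClassification
+        >>> model = CnnClassification(num_features=8, box_shape=(10, 10, 10))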
+    """
+
     def __init__(self, num_features, box_shape):
         super().__init__()
diff --git a/deeprank2/neuralnets/gnn/alignmentnet.py b/deeprank2/neuralnets/gnn/alignmentnet.py
index 077cd5dc..315beac6 100644
--- a/deeprank2/neuralnets/gnn/alignmentnet.py
+++ b/deeprank2/neuralnets/gnn/alignmentnet.py
@@ -6,7 +6,19 @@
 __author__ = "Daniel-Tobias Rademaker"
 
 
-class GNNLayer(nn.Module):  # noqa: D101
+class GNNLayer(nn.Module):
+    """Custom-defined layer of a Graph Neural Network.
+
+    Args:
+        nmb_edge_projection: Number of features in the edge projection.
+        nmb_hidden_attr: Number of features in the hidden attributes.
+        nmb_output_features: Number of output features.
+        message_vector_length: Length of the message vector.
+        nmb_mlp_neurons: Number of neurons in the MLP.
+        act_fn: Activation function. Defaults to nn.SiLU().
+        is_last_layer: Whether this is the last layer of the GNN. Defaults to True.
+    """
+
     def __init__(
         self,
         nmb_edge_projection,
@@ -104,7 +116,26 @@ def output(self, hidden_features, get_attention=True):
         return output
 
 
-class SuperGNN(nn.Module):  # noqa: D101
+class SuperGNN(nn.Module):
+    """SuperGNN is a class that defines multiple GNN layers.
+
+    In particular, the `preproc_edge_mlp` and `preproc_node_mlp` are meant to
+    preprocess the edge and node attributes, respectively.
+
+    The `modlist` is a list of GNNLayer objects.
+
+    Args:
+        nmb_edge_attr: Number of edge features.
+        nmb_node_attr: Number of node features.
+        nmb_hidden_attr: Number of hidden features.
+        nmb_mlp_neurons: Number of neurons in the MLP.
+        nmb_edge_projection: Number of edge projections.
+        nmb_gnn_layers: Number of GNN layers.
+        nmb_output_features: Number of output features.
+        message_vector_length: Length of the message vector.
+        act_fn: Activation function. Defaults to nn.SiLU().
+    """
+
     def __init__(
         self,
         nmb_edge_attr,
@@ -172,7 +203,25 @@ def run_through_network(self, edges, edge_attr, node_attr, with_output_attention
         return self.modlist[-1].output(node_attr, True)  # (boolean-positional-value-in-call)
 
 
-class AlignmentGNN(SuperGNN):  # noqa: D101
+class AlignmentGNN(SuperGNN):
+    """Architecture based on multiple :class:`GNNLayer` layers, suited for both regression and classification tasks.
+
+    It applies different layers to the nodes and edges of a graph (`preproc_edge_mlp` and `preproc_node_mlp`),
+    and then applies multiple GNN layers (`modlist`).
+
+    Args:
+        nmb_edge_attr: Number of edge features.
+        nmb_node_attr: Number of node features.
+        nmb_output_features: Number of output features.
+        nmb_hidden_attr: Number of hidden features.
+        message_vector_length: Length of the message vector.
+        nmb_mlp_neurons: Number of neurons in the MLP.
+        nmb_gnn_layers: Number of GNN layers.
+        nmb_edge_projection: Number of edge projections.
+        act_fn: Activation function. Defaults to nn.SiLU().
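+
+    Example:
+        A minimal instantiation sketch; all values below are arbitrary and the
+        keyword names follow the signature above:
+
+        >>> from deeprank2.neuralnets.gnn.alignmentnet import AlignmentGNN
+        >>> model = AlignmentGNN(
+        ...     nmb_edge_attr=2,
+        ...     nmb_node_attr=16,
+        ...     nmb_output_features=1,
+        ...     nmb_hidden_attr=32,
+        ...     message_vector_length=32,
+        ...     nmb_mlp_neurons=64,
+        ...     nmb_gnn_layers=3,
+        ...     nmb_edge_projection=8,
+        ... )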
+    """
+
     def __init__(
         self,
         nmb_edge_attr,
diff --git a/deeprank2/neuralnets/gnn/foutnet.py b/deeprank2/neuralnets/gnn/foutnet.py
index 8e4c4fad..83b58555 100644
--- a/deeprank2/neuralnets/gnn/foutnet.py
+++ b/deeprank2/neuralnets/gnn/foutnet.py
@@ -13,14 +13,13 @@
 class FoutLayer(nn.Module):
     """FoutLayer.
 
-    This layer is described by eq. (1) of
-    Protein Interface Predition using Graph Convolutional Network
+    This layer is described by eq. (1) of Protein Interface Prediction using Graph Convolutional Network
     by Alex Fout et al. NIPS 2018.
 
     Args:
         in_channels: Size of each input sample.
         out_channels: Size of each output sample.
-        bias: If set to :obj:`False`, the layer will not learn an additive bias. Defaults to True.
+        bias: If set to `False`, the layer will not learn an additive bias. Defaults to True.
     """
 
     def __init__(self, in_channels: int, out_channels: int, bias: bool = True):
@@ -70,7 +69,17 @@ def __repr__(self):
         return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})"
 
 
-class FoutNet(nn.Module):  # noqa: D101
+class FoutNet(nn.Module):
+    """Architecture based on the FoutLayer, suited for both regression and classification tasks.
+
+    It also uses community pooling to reduce the number of nodes.
+
+    Args:
+        input_shape: Size of each input sample.
+        output_shape: Size of each output sample. Defaults to 1.
+        input_shape_edge: Size of each input edge. Defaults to None.
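+
+    Example:
+        A minimal instantiation sketch (the node feature count is arbitrary):
+
+        >>> from deeprank2.neuralnets.gnn.foutnet import FoutNet
+        >>> model = FoutNet(input_shape=16)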
+    """
+
     def __init__(
         self,
         input_shape,
diff --git a/deeprank2/neuralnets/gnn/ginet.py b/deeprank2/neuralnets/gnn/ginet.py
index ec0bdf0f..6bd98cf3 100644
--- a/deeprank2/neuralnets/gnn/ginet.py
+++ b/deeprank2/neuralnets/gnn/ginet.py
@@ -10,7 +10,16 @@
 # ruff: noqa: ANN001, ANN201
 
 
-class GINetConvLayer(nn.Module):  # noqa: D101
+class GINetConvLayer(nn.Module):
+    """GiNet convolutional layer for graph neural networks.
+
+    Args:
+        in_channels: Number of input features.
+        out_channels: Number of output features.
+        number_edge_features: Number of edge features. Defaults to 1.
+        bias: If set to `False`, the layer will not learn an additive bias. Defaults to False.
+    """
+
     def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False):
         super().__init__()
@@ -54,10 +63,17 @@ def __repr__(self):
         return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})"
 
 
-class GINet(nn.Module):  # noqa: D101
-    # input_shape -> number of node input features
-    # output_shape -> number of output value per graph
-    # input_shape_edge -> number of edge input features
+class GINet(nn.Module):
+    """Architecture based on the GiNet convolutional layer, suited for both regression and classification tasks.
+
+    It uses community pooling to reduce the number of nodes.
+
+    Args:
+        input_shape: Number of input features.
+        output_shape: Number of output values per graph. Defaults to 1.
+        input_shape_edge: Number of edge input features. Defaults to 1.
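+
+    Example:
+        A minimal instantiation sketch (feature counts are arbitrary):
+
+        >>> from deeprank2.neuralnets.gnn.ginet import GINet
+        >>> model = GINet(input_shape=16, output_shape=1, input_shape_edge=1)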
+    """
+
     def __init__(self, input_shape, output_shape=1, input_shape_edge=1):
         super().__init__()
         self.conv1 = GINetConvLayer(input_shape, 16, input_shape_edge)
diff --git a/deeprank2/neuralnets/gnn/ginet_nocluster.py b/deeprank2/neuralnets/gnn/ginet_nocluster.py
index a8ebd809..5b6dd2d0 100644
--- a/deeprank2/neuralnets/gnn/ginet_nocluster.py
+++ b/deeprank2/neuralnets/gnn/ginet_nocluster.py
@@ -7,7 +7,16 @@
 # ruff: noqa: ANN001, ANN201
 
 
-class GINetConvLayer(nn.Module):  # noqa: D101
+class GINetConvLayer(nn.Module):
+    """GiNet convolutional layer for graph neural networks.
+
+    Args:
+        in_channels: Number of input features.
+        out_channels: Number of output features.
+        number_edge_features: Number of edge features. Defaults to 1.
+        bias: If set to `False`, the layer will not learn an additive bias. Defaults to False.
+    """
+
     def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False):
         super().__init__()
@@ -51,10 +60,15 @@ def __repr__(self):
         return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})"
 
 
-class GINet(nn.Module):  # noqa: D101
-    # input_shape -> number of node input features
-    # output_shape -> number of output value per graph
-    # input_shape_edge -> number of edge input features
+class GINet(nn.Module):
+    """Architecture based on the GiNet convolutional layer, suited for both regression and classification tasks.
+
+    Args:
+        input_shape: Number of input features.
+        output_shape: Number of output values per graph. Defaults to 1.
+        input_shape_edge: Number of edge input features. Defaults to 1.
+    """
+
     def __init__(self, input_shape, output_shape=1, input_shape_edge=1):
         super().__init__()
         self.conv1 = GINetConvLayer(input_shape, 16, input_shape_edge)
diff --git a/deeprank2/neuralnets/gnn/sgat.py b/deeprank2/neuralnets/gnn/sgat.py
index 10ca1730..1e5b7119 100644
--- a/deeprank2/neuralnets/gnn/sgat.py
+++ b/deeprank2/neuralnets/gnn/sgat.py
@@ -23,7 +23,7 @@ class SGraphAttentionLayer(nn.Module):
     Args:
         in_channels: Size of each input sample.
         out_channels: Size of each output sample.
-        bias: If set to :obj:`False`, the layer will not learn an additive bias. Defaults to True.
+        bias: If set to `False`, the layer will not learn an additive bias. Defaults to True.
     """  # noqa: D301
 
     def __init__(
@@ -87,7 +87,17 @@ def __repr__(self):
         return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})"
 
 
-class SGAT(nn.Module):  # noqa:D101
+class SGAT(nn.Module):
+    """Simple graph attention network, suited for both regression and classification tasks.
+
+    It uses two graph attention layers and an MLP to predict the output.
+
+    Args:
+        input_shape: Size of each input sample.
+        output_shape: Size of each output sample. Defaults to 1.
+        input_shape_edge: Size of each input edge. Defaults to None.
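+
+    Example:
+        A minimal instantiation sketch (the node feature count is arbitrary):
+
+        >>> from deeprank2.neuralnets.gnn.sgat import SGAT
+        >>> model = SGAT(input_shape=16)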
+    """
+
     def __init__(
         self,
         input_shape,
diff --git a/deeprank2/neuralnets/gnn/naive_gnn.py b/deeprank2/neuralnets/gnn/vanilla_gnn.py
similarity index 73%
rename from deeprank2/neuralnets/gnn/naive_gnn.py
rename to deeprank2/neuralnets/gnn/vanilla_gnn.py
index 8f936045..5b2fb00b 100644
--- a/deeprank2/neuralnets/gnn/naive_gnn.py
+++ b/deeprank2/neuralnets/gnn/vanilla_gnn.py
@@ -7,7 +7,14 @@
 # ruff: noqa: ANN001, ANN201
 
 
-class NaiveConvolutionalLayer(nn.Module):  # noqa: D101
+class VanillaConvolutionalLayer(nn.Module):
+    """Vanilla convolutional layer for graph neural networks.
+
+    Args:
+        count_node_features: Number of node features.
+        count_edge_features: Number of edge features.
+    """
+
     def __init__(self, count_node_features, count_edge_features):
         super().__init__()
         message_size = 32
@@ -31,18 +38,21 @@ def forward(self, node_features, edge_node_indices, edge_features):
         return self._node_mlp(node_input)
 
 
-class NaiveNetwork(nn.Module):  # noqa: D101
-    def __init__(self, input_shape: int, output_shape: int, input_shape_edge: int):
-        """NaiveNetwork.
+class VanillaNetwork(nn.Module):
+    """Vanilla graph neural network architecture suited for both regression and classification tasks.
 
-        Args:
-            input_shape: Number of node input features.
-            output_shape: Number of output value per graph.
-            input_shape_edge: Number of edge input features.
-        """
+    It uses two vanilla convolutional layers and an MLP to predict the output.
+
+    Args:
+        input_shape: Number of node input features.
+        output_shape: Number of output values per graph.
+        input_shape_edge: Number of edge input features.
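+
+    Example:
+        A minimal instantiation sketch (feature counts are arbitrary; all three
+        arguments are required):
+
+        >>> from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
+        >>> model = VanillaNetwork(input_shape=16, output_shape=1, input_shape_edge=2)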
+    """
+
+    def __init__(self, input_shape: int, output_shape: int, input_shape_edge: int):
         super().__init__()
-        self._external1 = NaiveConvolutionalLayer(input_shape, input_shape_edge)
-        self._external2 = NaiveConvolutionalLayer(input_shape, input_shape_edge)
+        self._external1 = VanillaConvolutionalLayer(input_shape, input_shape_edge)
+        self._external2 = VanillaConvolutionalLayer(input_shape, input_shape_edge)
         hidden_size = 128
         self._graph_mlp = nn.Sequential(nn.Linear(input_shape, hidden_size), nn.ReLU(), nn.Linear(hidden_size, output_shape))
diff --git a/docs/source/deeprank2.neuralnets.rst b/docs/source/deeprank2.neuralnets.rst
index 1474056f..a228548d 100644
--- a/docs/source/deeprank2.neuralnets.rst
+++ b/docs/source/deeprank2.neuralnets.rst
@@ -41,10 +41,10 @@ deeprank2.neuralnets.gnn.ginet\_nocluster
    :undoc-members:
    :show-inheritance:
 
-deeprank2.neuralnets.gnn.naive\_gnn
+deeprank2.neuralnets.gnn.vanilla\_gnn
 ------------------------------------------
 
-.. automodule:: deeprank2.neuralnets.gnn.naive_gnn
+.. automodule:: deeprank2.neuralnets.gnn.vanilla_gnn
    :members:
    :undoc-members:
    :show-inheritance:
diff --git a/docs/source/getstarted.md b/docs/source/getstarted.md
index 15abe354..cca87829 100644
--- a/docs/source/getstarted.md
+++ b/docs/source/getstarted.md
@@ -312,10 +312,10 @@ Let's define a `Trainer` instance, using for example of the already existing `GI
 ```python
 from deeprank2.trainer import Trainer
-from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
+from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
 
 trainer = Trainer(
-    NaiveNetwork,
+    VanillaNetwork,
     dataset_train,
     dataset_val,
     dataset_test
@@ -366,11 +366,11 @@ The user can specify a DeepRank2 exporter or a custom one in `output_exporters`
 ```python
 from deeprank2.trainer import Trainer
-from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
+from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
 from deeprank2.utils.exporters import HDF5OutputExporter
 
 trainer = Trainer(
-    NaiveNetwork,
+    VanillaNetwork,
     dataset_train,
     dataset_val,
     dataset_test,
@@ -452,11 +452,11 @@ Finally, the `Trainer` instance can be defined and the new data can be tested:
 ```python
 from deeprank2.trainer import Trainer
-from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
+from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
 from deeprank2.utils.exporters import HDF5OutputExporter
 
 trainer = Trainer(
-    NaiveNetwork,
+    VanillaNetwork,
     dataset_test = dataset_test,
     pretrained_model = "",
     output_exporters = [HDF5OutputExporter("")]
diff --git a/tests/test_integration.py b/tests/test_integration.py
index b0fce68b..8cf4b5e6 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -14,7 +14,7 @@
 from deeprank2.domain import targetstorage as targets
 from deeprank2.neuralnets.cnn.model3d import CnnClassification
 from deeprank2.neuralnets.gnn.ginet import GINet
-from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
+from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
 from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection
 from deeprank2.tools.target import compute_ppi_scores
 from deeprank2.trainer import Trainer
@@ -273,7 +273,7 @@ def test_nan_loss_cases(
         train_source=dataset_train,
     )
 
-    trainer = Trainer(NaiveNetwork, dataset_train, dataset_valid)
+    trainer = Trainer(VanillaNetwork, dataset_train, dataset_valid)
 
     optimizer = torch.optim.SGD
     lr = 10000
diff --git a/tests/test_set_lossfunction.py b/tests/test_set_lossfunction.py
index 252fd326..e63c217a 100644
--- a/tests/test_set_lossfunction.py
+++ b/tests/test_set_lossfunction.py
@@ -9,7 +9,7 @@
 from deeprank2.dataset import GraphDataset
 from deeprank2.domain import losstypes as losses
 from deeprank2.domain import targetstorage as targets
-from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
+from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
 from deeprank2.trainer import Trainer
 
 hdf5_path = "tests/data/hdf5/test.hdf5"
@@ -29,7 +29,7 @@ def base_test(
     trainer.train(nepoch=2, best_model=False, filename=model_path)
 
     return Trainer(
-        neuralnet=NaiveNetwork,
+        neuralnet=VanillaNetwork,
         dataset_test=trainer.dataset_train,
         pretrained_model=model_path,
     )
@@ -52,7 +52,7 @@ def test_classif_default(self) -> None:
             target=targets.BINARY,
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
@@ -66,7 +66,7 @@ def test_classif_all(self) -> None:
             target=targets.BINARY,
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
@@ -83,7 +83,7 @@ def test_classif_weighted(self) -> None:
             target=targets.BINARY,
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
             class_weights=True,
         )
@@ -100,7 +100,7 @@ def test_classif_invalid_weighted(self) -> None:
             target=targets.BINARY,
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
             class_weights=True,
         )
@@ -116,7 +116,7 @@ def test_classif_invalid_lossfunction(self) -> None:
             target=targets.BINARY,
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
         lossfunction = nn.MSELoss
@@ -127,7 +127,7 @@ def test_classif_invalid_lossfunction_override(self) -> None:
         dataset = GraphDataset(hdf5_path, target=targets.BINARY)
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
         lossfunction = nn.MSELoss
@@ -148,7 +148,7 @@ def test_regress_default(self) -> None:
             task="regress",
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
@@ -163,7 +163,7 @@ def test_regress_all(self) -> None:
             task="regress",
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
         for f in losses.regression_losses:
@@ -180,7 +180,7 @@ def test_regress_invalid_lossfunction(self) -> None:
             task="regress",
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
         lossfunction = nn.CrossEntropyLoss
@@ -195,7 +195,7 @@ def test_regress_invalid_lossfunction_override(self) -> None:
             task="regress",
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
         lossfunction = nn.CrossEntropyLoss
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index fb462e7a..b57ffdd6 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -19,8 +19,8 @@
 from deeprank2.neuralnets.cnn.model3d import CnnClassification, CnnRegression
 from deeprank2.neuralnets.gnn.foutnet import FoutNet
 from deeprank2.neuralnets.gnn.ginet import GINet
-from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
 from deeprank2.neuralnets.gnn.sgat import SGAT
+from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork
 from deeprank2.trainer import Trainer, _divide_dataset
 from deeprank2.utils.exporters import HDF5OutputExporter, ScatterPlotExporter, TensorboardBinaryClassificationExporter
@@ -262,7 +262,7 @@ def test_sgat(self) -> None:
         )
         assert len(os.listdir(self.work_directory)) > 0
 
-    def test_naive(self) -> None:
+    def test_vanilla(self) -> None:
         files = glob.glob(self.work_directory + "/*")
         for f in files:
             os.remove(f)
@@ -270,7 +270,7 @@
         _model_base_test(
             self.save_path,
-            NaiveNetwork,
+            VanillaNetwork,
             "tests/data/hdf5/test.hdf5",
             "tests/data/hdf5/test.hdf5",
             "tests/data/hdf5/test.hdf5",
@@ -332,7 +332,7 @@ def test_incompatible_no_pretrained_no_train(self) -> None:
         with pytest.raises(ValueError):
             Trainer(
-                neuralnet=NaiveNetwork,
+                neuralnet=VanillaNetwork,
                 dataset_test=dataset,
             )
@@ -456,7 +456,7 @@ def test_optim(self) -> None:
             target=targets.BINARY,
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
@@ -472,7 +472,7 @@ def test_optim(self) -> None:
         with warnings.catch_warnings(record=UserWarning):
             trainer.train(nepoch=3, best_model=False, filename=self.save_path)
             trainer_pretrained = Trainer(
-                neuralnet=NaiveNetwork,
+                neuralnet=VanillaNetwork,
                 dataset_test=dataset,
                 pretrained_model=self.save_path,
             )
@@ -487,7 +487,7 @@ def test_default_optim(self) -> None:
             target=targets.BINARY,
         )
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_train=dataset,
         )
@@ -662,7 +662,7 @@ def test_train_method_no_train(self) -> None:
         dataset_test = GraphDataset(hdf5_path=test_data_graph, train_source=pretrained_model_graph)
 
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_test=dataset_test,
             pretrained_model=pretrained_model_graph,
         )
@@ -692,7 +692,7 @@ def test_test_method_pretrained_model_on_dataset_with_target(self) -> None:
         dataset_test = GraphDataset(hdf5_path=test_data_graph, train_source=pretrained_model_graph)
 
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_test=dataset_test,
             pretrained_model=pretrained_model_graph,
             output_exporters=[HDF5OutputExporter("./")],
         )
@@ -729,7 +729,7 @@ def test_test_method_pretrained_model_on_dataset_without_target(self) -> None:
         dataset_test = GraphDataset(hdf5_path=test_data_graph, train_source=pretrained_model_graph)
 
         trainer = Trainer(
-            neuralnet=NaiveNetwork,
+            neuralnet=VanillaNetwork,
             dataset_test=dataset_test,
             pretrained_model=pretrained_model_graph,
             output_exporters=[HDF5OutputExporter("./")],
         )
@@ -777,7 +777,7 @@ def test_graph_save_and_load_model(self) -> None:
             task=targets.CLASSIF,
             features_transform=features_transform,
         )
-        trainer = Trainer(NaiveNetwork, dataset)
+        trainer = Trainer(VanillaNetwork, dataset)
         # during the training the model is saved
         trainer.train(nepoch=2, batch_size=2, filename=self.save_path)
         assert trainer.features_transform == features_transform
diff --git a/tutorials/training.ipynb b/tutorials/training.ipynb
index 401c3403..499c3f8e 100644
--- a/tutorials/training.ipynb
+++ b/tutorials/training.ipynb
@@ -86,7 +86,7 @@
 "logging.basicConfig(level=logging.INFO)\n",
 "from deeprank2.dataset import GraphDataset, GridDataset\n",
 "from deeprank2.trainer import Trainer\n",
-"from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork\n",
+"from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork\n",
 "from deeprank2.neuralnets.cnn.model3d import CnnClassification\n",
 "from deeprank2.utils.exporters import HDF5OutputExporter\n",
 "import warnings\n",
@@ -287,7 +287,7 @@
 "source": [
 "A few notes about `Trainer` parameters:\n",
 "\n",
-"- `neuralnet` can be any neural network class that inherits from `torch.nn.Module`, and it shouldn't be specific to regression or classification in terms of output shape. The `Trainer` class takes care of formatting the output shape according to the task. This tutorial uses a simple network, `NaiveNetwork` (implemented in `deeprank2.neuralnets.gnn.naive_gnn`). All GNN architectures already implemented in the pakcage can be found [here](https://github.com/DeepRank/deeprank-core/tree/main/deeprank2/neuralnets/gnn) and can be used for training or as a basis for implementing new ones.\n",
+"- `neuralnet` can be any neural network class that inherits from `torch.nn.Module`, and it shouldn't be specific to regression or classification in terms of output shape. The `Trainer` class takes care of formatting the output shape according to the task. This tutorial uses a simple network, `VanillaNetwork` (implemented in `deeprank2.neuralnets.gnn.vanilla_gnn`). All GNN architectures already implemented in the package can be found [here](https://github.com/DeepRank/deeprank-core/tree/main/deeprank2/neuralnets/gnn) and can be used for training or as a basis for implementing new ones.\n",
 "- `class_weights` is used for classification tasks only and assigns class weights based on the training dataset content to account for any potential inbalance between the classes. In this case the dataset is balanced (50% 0 and 50% 1), so it is not necessary to use it. It defaults to False.\n",
 "- `cuda` and `ngpu` are used for indicating whether to use CUDA and how many GPUs. By default, CUDA is not used and `ngpu` is 0.\n",
 "- The user can specify a deeprank2 exporter or a custom one in `output_exporters` parameter, together with the path where to save the results. Exporters are used for storing predictions information collected later on during training and testing. Later the results saved by `HDF5OutputExporter` will be read and evaluated.\n"
@@ -308,7 +308,7 @@
 "outputs": [],
 "source": [
 "trainer = Trainer(\n",
-"    neuralnet=NaiveNetwork,\n",
+"    neuralnet=VanillaNetwork,\n",
 "    dataset_train=dataset_train,\n",
 "    dataset_val=dataset_val,\n",
 "    dataset_test=dataset_test,\n",