diff --git a/.vscode/settings.json b/.vscode/settings.json index a519af241..6b68b7366 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -19,5 +19,11 @@ // Format all files on save "editor.formatOnSave": true, - "editor.defaultFormatter": "esbenp.prettier-vscode" + "editor.defaultFormatter": "esbenp.prettier-vscode", + "[ignore]": { + "editor.defaultFormatter": "foxundermoon.shell-format" + }, + "[properties]": { + "editor.defaultFormatter": "foxundermoon.shell-format" + } } diff --git a/README.dev.md b/README.dev.md new file mode 100644 index 000000000..256c222ec --- /dev/null +++ b/README.dev.md @@ -0,0 +1,93 @@ +# `deeprank2` developer documentation + +If you're looking for user documentation, go [here](README.md). + +## Code editor + +We use [Visual Studio Code (VS Code)](https://code.visualstudio.com/) as our code editor. +The VS Code settings for this project can be found in [.vscode](.vscode). +The settings will be automatically loaded and applied when you open the project with VS Code. +See [the guide](https://code.visualstudio.com/docs/getstarted/settings) for more information about VS Code workspace settings. + +## Package setup + +After following the [installation instructions](https://github.com/DeepRank/deeprank2#installation) and installing all of the package's dependencies, clone the repository and install it in editable mode: + +```bash +git clone https://github.com/DeepRank/deeprank2 +cd deeprank2 +pip install -e .'[test]' +``` + +## Running the tests + +You can check that all components were installed correctly using pytest. +The quick test should be sufficient to ensure that the software works, while the full test will cover a much broader range of settings to ensure everything is correct. + +Run `pytest tests/test_integration.py` for the quick test or just `pytest` for the full test (expect a few minutes to run). + +## Test coverage + +In addition to checking whether the tests pass, they can be used to collect coverage statistics, i.e. to determine how much of the package's code is actually executed during the tests. In an activated conda environment with the development tools installed, inside the package directory, run: + +```bash +coverage run -m pytest +``` + +This runs the tests and stores the result in a `.coverage` file. To see the results on the command line, run: + +```bash +coverage report +``` + +`coverage` can also generate output in HTML and other formats; see `coverage help` for more information. + +## Linting and Formatting + +We use [ruff](https://docs.astral.sh/ruff/) for linting, import sorting, and formatting of Python files and notebooks. The configuration of `ruff` is set in the [pyproject.toml](pyproject.toml) file. + +If you are using VS Code, please install and activate the [Ruff extension](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) to automatically format your files and check linting. + +Otherwise, please run both the linter (`ruff check .`) and the formatter (`ruff format .`) before requesting a review. + +We use [prettier](https://prettier.io/) for formatting most other files. If you are editing or adding non-Python files and using VS Code, the [Prettier extension](https://marketplace.visualstudio.com/items?itemName=esbenp.prettier-vscode) can be installed to auto-format these files as well.
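For convenience, a typical pre-review check run from the repository root could look like the following sketch (it assumes `ruff` is available in your development environment, e.g. installed alongside the other development dependencies):

```bash
# lint all Python files and notebooks, applying safe auto-fixes
ruff check . --fix

# format all Python files and notebooks according to the settings in pyproject.toml
ruff format .
```

If the linter still reports issues after the auto-fixes, resolve them manually before opening the pull request.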
+ +## Versioning + +Before creating a new package release, bump the version across all files by running `bump2version [part]` from the command line, after having installed [bump2version](https://pypi.org/project/bump2version/) in your local environment. Replace `[part]` with the part of the version to increase, e.g. `minor`. The settings in `.bumpversion.cfg` will take care of updating all the files containing version strings. + +## Branching workflow + +We use a [Git Flow](https://nvie.com/posts/a-successful-git-branching-model/)-inspired branching workflow for development. DeepRank2's repository is based on two main branches with infinite lifetime: + +- `main` — this branch contains production (stable) code. All development code is eventually merged into `main`. +- `dev` — this branch contains pre-production code. When features are finished, they are merged into `dev`. + +During the development cycle, three main supporting branches are used: + +- Feature branches - Branches that branch off from `dev` and must merge into `dev`: used to develop new features for the upcoming releases. +- Hotfix branches - Branches that branch off from `main` and must merge into `main` and `dev`: necessary to act immediately upon an undesired status of `main`. +- Release branches - Branches that branch off from `dev` and must merge into `main` and `dev`: support preparation of a new production release. They allow minor bugs to be fixed and release metadata to be prepared. + +### Development conventions + +- Branching + - When creating a new branch, please use the following convention: `__`. + - Always branch from the `dev` branch, unless you need to fix an undesired status of `main`. See above for more details about the adopted branching workflow. +- Pull Requests + - When creating a pull request, please use the following convention: `: `. Example _types_ are `fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, `test:`, and others based on the [Angular convention](https://github.com/angular/angular/blob/22b96b9/CONTRIBUTING.md#-commit-message-guidelines). + +## Making a release + +1. Branch from `dev` and prepare the branch for the release (e.g., removing unnecessary dev files such as the current one, and fixing minor bugs if necessary). +2. [Bump the version](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#versioning). +3. Verify that the information in `CITATION.cff` is correct (update the release date), and that `.zenodo.json` contains equivalent data. +4. Merge the release branch into `main` (and `dev`), and [run the tests](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#running-the-tests). +5. Go to https://github.com/DeepRank/deeprank2/releases and draft a new release; create a new tag for the release, generate release notes automatically and adjust them, and finally publish the release as latest. This will trigger [a GitHub action](https://github.com/DeepRank/deeprank2/actions/workflows/release.yml) that will take care of publishing the package on PyPI. + +## UML + +Code-base class diagrams updated on 02/11/2023, generated with https://www.gituml.com (save the images and open them in the browser for zooming):
+ +- Data processing classes and functions: +- ML pipeline classes and functions: diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py index 02a6ba4a0..76bb0281e 100644 --- a/deeprank2/dataset.py +++ b/deeprank2/dataset.py @@ -23,12 +23,16 @@ from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain import targetstorage as targets -# ruff: noqa: PYI051 (redundant-literal-union), the literal is a special case, while the str is generic - _log = logging.getLogger(__name__) class DeeprankDataset(Dataset): + """Parent class of :class:`GridDataset` and :class:`GraphDataset`. + + This class inherits from :class:`torch_geometric.data.dataset.Dataset`. + More detailed information about the parameters can be found in :class:`GridDataset` and :class:`GraphDataset`. + """ + def __init__( self, hdf5_path: str | list[str], @@ -43,11 +47,6 @@ def __init__( root: str, check_integrity: bool, ): - """Parent class of :class:`GridDataset` and :class:`GraphDataset`. - - This class inherits from :class:`torch_geometric.data.dataset.Dataset`. - More detailed information about the parameters can be found in :class:`GridDataset` and :class:`GraphDataset`. - """ super().__init__(root) if isinstance(hdf5_path, str): @@ -57,7 +56,8 @@ def __init__( self.hdf5_paths = hdf5_path else: - raise TypeError(f"hdf5_path: unexpected type: {type(hdf5_path)}") + msg = f"hdf5_path: unexpected type: {type(hdf5_path)}" + raise TypeError(msg) self.subset = subset self.train_source = train_source @@ -85,11 +85,11 @@ def __init__( # get the device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def _check_and_inherit_train( + def _check_and_inherit_train( # noqa: C901 self, data_type: GridDataset | GraphDataset, - inherited_params, - ): + inherited_params: list[str], + ) -> None: """Check if the pre-trained model or training set provided are valid for validation and/or testing, and inherit the parameters.""" if isinstance(self.train_source, str): try: @@ -98,11 +98,12 @@ def _check_and_inherit_train( else: data = torch.load(self.train_source, map_location=torch.device("cpu")) if data["data_type"] is not data_type: - raise TypeError( + msg = ( f"The pre-trained model has been trained with data of type {data['data_type']}, but you are trying \n\t" f"to define a {data_type}-class validation/testing dataset. Please provide a valid DeepRank2 \n\t" f"model trained with {data_type}-class type data, or define the dataset using the appropriate class." ) + raise TypeError(msg) if data_type is GraphDataset: self.train_means = data["means"] self.train_devs = data["devs"] @@ -111,24 +112,31 @@ def _check_and_inherit_train( for key in data["features_transform"].values(): if key["transform"] is None: continue - key["transform"] = eval(key["transform"]) # noqa: S307, PGH001 (suspicious-eval-usage) + key["transform"] = eval(key["transform"]) # noqa: S307, PGH001 except pickle.UnpicklingError as e: - raise ValueError("The path provided to `train_source` is not a valid DeepRank2 pre-trained model.") from e + msg = "The path provided to `train_source` is not a valid DeepRank2 pre-trained model." + raise ValueError(msg) from e elif isinstance(self.train_source, data_type): data = self.train_source if data_type is GraphDataset: self.train_means = self.train_source.means self.train_devs = self.train_source.devs else: - raise TypeError( + msg = ( f"The train data provided is invalid: {type(self.train_source)}.\n\t" f"Please provide a valid training {data_type} or the path to a valid DeepRank2 pre-trained model." 
) + raise TypeError( + msg, + ) + raise TypeError( + msg, + ) # match parameters with the ones in the training set self._check_inherited_params(inherited_params, data) - def _check_hdf5_files(self): + def _check_hdf5_files(self) -> None: """Checks if the data contained in the .HDF5 file is valid.""" _log.info("\nChecking dataset Integrity...") to_be_removed = [] @@ -139,7 +147,7 @@ def _check_hdf5_files(self): if len(entry_names) == 0: _log.info(f" -> {hdf5_path} is empty ") to_be_removed.append(hdf5_path) - except Exception as e: # noqa: BLE001, PERF203 (blind-except, try-except-in-loop) + except Exception as e: # noqa: BLE001, PERF203 _log.error(e) _log.info(f" -> {hdf5_path} is corrupted ") to_be_removed.append(hdf5_path) @@ -147,7 +155,7 @@ def _check_hdf5_files(self): for hdf5_path in to_be_removed: self.hdf5_paths.remove(hdf5_path) - def _check_task_and_classes(self, task: str, classes: str | None = None): + def _check_task_and_classes(self, task: str, classes: str | None = None) -> None: if self.target in [targets.IRMSD, targets.LRMSD, targets.FNAT, targets.DOCKQ]: self.task = targets.REGRESS @@ -158,11 +166,12 @@ def _check_task_and_classes(self, task: str, classes: str | None = None): self.task = task if self.task not in [targets.CLASSIF, targets.REGRESS] and self.target is not None: - raise ValueError(f"User target detected: {self.target} -> The task argument must be 'classif' or 'regress', currently set as {self.task}") + msg = f"User target detected: {self.target} -> The task argument must be 'classif' or 'regress', currently set as {self.task}" + raise ValueError(msg) if task != self.task and task is not None: warnings.warn( - f"Target {self.target} expects {self.task}, but was set to task {task} by user.\nUser set task is ignored and {self.task} will be used." + f"Target {self.target} expects {self.task}, but was set to task {task} by user.\nUser set task is ignored and {self.task} will be used.", ) if self.task == targets.CLASSIF: @@ -181,7 +190,7 @@ def _check_inherited_params( self, inherited_params: list[str], data: dict | GraphDataset | GridDataset, - ): + ) -> None: """Check if the parameters for validation and/or testing are the same as in the pre-trained model or training set provided. Args: @@ -199,11 +208,11 @@ def _check_inherited_params( _log.warning( f"The {param} parameter set here is: {self_vars[param]}, " f"which is not equivalent to the one in the training phase: {data[param]}./n" - f"Overwriting {param} parameter with the one used in the training phase." + f"Overwriting {param} parameter with the one used in the training phase.", ) setattr(self, param, data[param]) - def _create_index_entries(self): + def _create_index_entries(self) -> None: """Creates the indexing of each molecule in the dataset. Creates the indexing: [ ('1ak4.hdf5,1AK4_100w),...,('1fqj.hdf5,1FGJ_400w)]. 
@@ -237,7 +246,7 @@ def _create_index_entries(self): else: self.index_entries += [(hdf5_path, entry_name) for entry_name in entry_names if self._filter_targets(hdf5_file[entry_name])] - except Exception: # noqa: BLE001 (blind-except) + except Exception: # noqa: BLE001 _log.exception(f"on {hdf5_path}") def _filter_targets(self, grp: h5py.Group) -> bool: @@ -269,11 +278,12 @@ def _filter_targets(self, grp: h5py.Group) -> bool: for operator_string in [">", "<", "==", "<=", ">=", "!="]: operation = operation.replace(operator_string, f"{target_value}" + operator_string) - if not eval(operation): # noqa: S307, PGH001 (suspicious-eval-usage) + if not eval(operation): # noqa: S307, PGH001 return False elif target_condition is not None: - raise ValueError("Conditions not supported", target_condition) + msg = "Conditions not supported" + raise ValueError(msg, target_condition) else: _log.warning(f" :Filter {target_name} not found for entry {grp}\n :Filter options are: {present_target_names}") @@ -287,7 +297,7 @@ def len(self) -> int: """ return len(self.index_entries) - def hdf5_to_pandas( + def hdf5_to_pandas( # noqa: C901 self, ) -> pd.DataFrame: """Loads features data from the HDF5 files into a Pandas DataFrame in the attribute `df` of the class. @@ -319,7 +329,7 @@ def hdf5_to_pandas( if (transform is None) and (feat in self.features_transform): transform = self.features_transform.get(feat, {}).get("transform") # Check the number of channels the features have - if f[entry_name][feat_type][feat][()].ndim == 2: + if f[entry_name][feat_type][feat][()].ndim == 2: # noqa:PLR2004 for i in range(f[entry_name][feat_type][feat][:].shape[1]): df_dict[feat + "_" + str(i)] = [f[entry_name][feat_type][feat][:][:, i] for entry_name in entry_names] # apply transformation for each channel in this feature @@ -339,14 +349,14 @@ def hdf5_to_pandas( self.df = df_concat.reset_index(drop=True) return self.df - def save_hist( + def save_hist( # noqa: C901 self, features: str | list[str], fname: str = "features_hist.png", bins: int | list[float] | str = 10, figsize: tuple = (15, 15), log: bool = False, - ): + ) -> None: """After having generated a pd.DataFrame using hdf5_to_pandas method, histograms of the features can be saved in an image. Args: @@ -428,7 +438,8 @@ def save_hist( ) else: - raise ValueError("Please provide valid features names. They must be present in the current :class:`DeeprankDataset` children instance.") + msg = "Please provide valid features names. They must be present in the current :class:`DeeprankDataset` children instance." + raise ValueError(msg) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -436,7 +447,7 @@ def save_hist( fig.savefig(fname) plt.close(fig) - def _compute_mean_std(self): + def _compute_mean_std(self) -> None: means = { col: round(np.nanmean(np.concatenate(self.df[col].values)), 1) if isinstance(self.df[col].to_numpy()[0], np.ndarray) @@ -460,12 +471,58 @@ def _compute_mean_std(self): class GridDataset(DeeprankDataset): + """Class to load the .HDF5 files data into grids. + + Args: + hdf5_path (str | list): Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list. Defaults to None. + subset (list[str] | None, optional): list of keys from .HDF5 file to include. Defaults to None (meaning include all). + train_source (str | class:`GridDataset` | None, optional): data to inherit information from the training dataset or the pre-trained model. + If None, the current dataset is considered as the training set. 
Otherwise, `train_source` needs to be a dataset of the same class or + the path of a DeepRank2 pre-trained model. If set, the parameters `features`, `target`, `traget_transform`, `task`, and `classes` + will be inherited from `train_source`. + Defaults to None. + features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed features ("all") or some defined node features + (provide a list, example: ["res_type", "polarity", "bsa"]). The complete list can be found in `deeprank2.domain.gridstorage`. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to "all". + target (str | None, optional): Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq. It can also be + a custom-defined target given to the Query class as input (see: `deeprank2.query`); in this case, + the task parameter needs to be explicitly specified as well. + Only numerical target variables are supported, not categorical. + If the latter is your case, please convert the categorical classes into + numerical class indices before defining the :class:`GraphDataset` instance. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to None. + target_transform (bool, optional): Apply a log and then a sigmoid transformation to the target (for regression only). + This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to False. + target_filter (dict[str, str] | None, optional): Dictionary of type [target: cond] to filter the molecules. + Note that the you can filter on a different target than the one selected as the dataset target. + Defaults to None. + task (Literal["regress", "classif"] | None, optional): 'regress' for regression or 'classif' for classification. Required if target not in + ['irmsd', 'lrmsd', 'fnat', 'binary', 'capri_class', or 'dockq'], otherwise this setting is ignored. + Automatically set to 'classif' if the target is 'binary' or 'capri_classes'. + Automatically set to 'regress' if the target is 'irmsd', 'lrmsd', 'fnat', or 'dockq'. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to None. + classes (list[str] | list[int] | list[float] | None): Define the dataset target classes in classification mode. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to None. + use_tqdm (bool, optional): Show progress bar. + Defaults to True. + root (str, optional): Root directory where the dataset should be saved. + Defaults to "./". + check_integrity (bool, optional): Whether to check the integrity of the hdf5 files. + Defaults to True. + """ + def __init__( self, hdf5_path: str | list, subset: list[str] | None = None, train_source: str | GridDataset | None = None, - features: list[str] | str | Literal["all"] | None = "all", + features: list[str] | str | None = "all", target: str | None = None, target_transform: bool = False, target_filter: dict[str, str] | None = None, @@ -475,51 +532,6 @@ def __init__( root: str = "./", check_integrity: bool = True, ): - """Class to load the .HDF5 files data into grids. - - Args: - hdf5_path (str | list): Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list. Defaults to None. - subset (list[str] | None, optional): list of keys from .HDF5 file to include. 
Defaults to None (meaning include all). - train_source (str | class:`GridDataset` | None, optional): data to inherit information from the training dataset or the pre-trained model. - If None, the current dataset is considered as the training set. Otherwise, `train_source` needs to be a dataset of the same class or - the path of a DeepRank2 pre-trained model. If set, the parameters `features`, `target`, `traget_transform`, `task`, and `classes` - will be inherited from `train_source`. - Defaults to None. - features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed features ("all") or some defined node features - (provide a list, example: ["res_type", "polarity", "bsa"]). The complete list can be found in `deeprank2.domain.gridstorage`. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to "all". - target (str | None, optional): Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq. It can also be - a custom-defined target given to the Query class as input (see: `deeprank2.query`); in this case, - the task parameter needs to be explicitly specified as well. - Only numerical target variables are supported, not categorical. - If the latter is your case, please convert the categorical classes into - numerical class indices before defining the :class:`GraphDataset` instance. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to None. - target_transform (bool, optional): Apply a log and then a sigmoid transformation to the target (for regression only). - This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to False. - target_filter (dict[str, str] | None, optional): Dictionary of type [target: cond] to filter the molecules. - Note that the you can filter on a different target than the one selected as the dataset target. - Defaults to None. - task (Literal["regress", "classif"] | None, optional): 'regress' for regression or 'classif' for classification. Required if target not in - ['irmsd', 'lrmsd', 'fnat', 'binary', 'capri_class', or 'dockq'], otherwise this setting is ignored. - Automatically set to 'classif' if the target is 'binary' or 'capri_classes'. - Automatically set to 'regress' if the target is 'irmsd', 'lrmsd', 'fnat', or 'dockq'. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to None. - classes (list[str] | list[int] | list[float] | None): Define the dataset target classes in classification mode. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to None. - use_tqdm (bool, optional): Show progress bar. - Defaults to True. - root (str, optional): Root directory where the dataset should be saved. - Defaults to "./". - check_integrity (bool, optional): Whether to check the integrity of the hdf5 files. - Defaults to True. - """ super().__init__( hdf5_path, subset, @@ -557,14 +569,17 @@ def __init__( try: fname, mol = self.index_entries[0] except IndexError as e: - raise IndexError("No entries found in the dataset. Please check the dataset parameters.") from e + msg = "No entries found in the dataset. Please check the dataset parameters." 
+ raise IndexError(msg) from e with h5py.File(fname, "r") as f5: grp = f5[mol] possible_targets = grp[targets.VALUES].keys() if self.target is None: - raise ValueError(f"Please set the target during training dataset definition; targets present in the file/s are {possible_targets}.") + msg = f"Please set the target during training dataset definition; targets present in the file/s are {possible_targets}." + raise ValueError(msg) if self.target not in possible_targets: - raise ValueError(f"Target {self.target} not present in the file/s; targets present in the file/s are {possible_targets}.") + msg = f"Target {self.target} not present in the file/s; targets present in the file/s are {possible_targets}." + raise ValueError(msg) self.features_dict = {} self.features_dict[gridstorage.MAPPED_FEATURES] = self.features @@ -574,7 +589,7 @@ def __init__( else: self.features_dict[targets.VALUES] = self.target - def _check_features(self): + def _check_features(self) -> None: # noqa: C901 """Checks if the required features exist.""" hdf5_path = self.hdf5_paths[0] @@ -630,13 +645,19 @@ def _check_features(self): # raise error if any features are missing if len(missing_features) > 0: - raise ValueError( + msg = ( f"Not all features could be found in the file {hdf5_path} under entry {mol_key}.\n\t" f"Missing features are: {missing_features}.\n\t" "Check feature_modules passed to the preprocess function.\n\t" "Probably, the feature wasn't generated during the preprocessing step.\n\t" f"Available features: {available_features}" ) + raise ValueError( + msg, + ) + raise ValueError( + msg, + ) def get(self, idx: int) -> Data: """Gets one grid item from its unique index. @@ -677,17 +698,22 @@ def load_one_grid(self, hdf5_path: str, entry_name: str) -> Data: if self.task == targets.REGRESS and self.target_transform is True: y = torch.sigmoid(torch.log(y)) elif self.task is not targets.REGRESS and self.target_transform is True: - raise ValueError( - f'Sigmoid transformation not possible for {self.task} tasks. Please change `task` to "regress" or set `target_transform` to `False`.' - ) + msg = f'Sigmoid transformation not possible for {self.task} tasks. Please change `task` to "regress" or set `target_transform` to `False`.' + raise ValueError(msg) else: y = None possible_targets = grp[targets.VALUES].keys() if self.train_source is None: - raise ValueError( + msg = ( f"Target {self.target} missing in entry {entry_name} in file {hdf5_path}, possible targets are {possible_targets}.\n\t" "Use the query class to add more target values to input data." ) + raise ValueError( + msg, + ) + raise ValueError( + msg, + ) # Wrap up the data in this object, for the collate_fn to handle it properly: data = Data(x=x, y=y) @@ -697,13 +723,82 @@ def load_one_grid(self, hdf5_path: str, entry_name: str) -> Data: class GraphDataset(DeeprankDataset): - def __init__( + """Class to load the .HDF5 files data into graphs. + + Args: + hdf5_path (str | list): Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list. Defaults to None. + subset (list[str] | None, optional): list of keys from .HDF5 file to include. Defaults to None (meaning include all). + train_source (str | class:`GraphDataset` | None, optional): data to inherit information from the training dataset or the pre-trained model. + If None, the current dataset is considered as the training set. Otherwise, `train_source` needs to be a dataset of the same class or + the path of a DeepRank2 pre-trained model. 
If set, the parameters `node_features`, `edge_features`, `features_transform`, + `target`, `target_transform`, `task`, and `classes` will be inherited from `train_source`. If standardization was performed in the + training dataset/step, also the attributes `means` and `devs` will be inherited from `train_source`, and they will be used to scale + the validation/testing set. + Defaults to None. + node_features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed node features ("all") or + some defined node features (provide a list, example: ["res_type", "polarity", "bsa"]). + The complete list can be found in `deeprank2.domain.nodestorage`. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to "all". + edge_features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed edge features ("all") or + some defined edge features (provide a list, example: ["dist", "coulomb"]). + The complete list can be found in `deeprank2.domain.edgestorage`. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to "all". + features_transform (dict | None, optional): Dictionary to indicate the transformations to apply to each feature in the dictionary, being the + transformations lambda functions and/or standardization. + Example: `features_transform = {'bsa': {'transform': lambda t:np.log(t+1),' standardize': True}}` for the feature `bsa`. + An `all` key can be set in the dictionary for indicating to apply the same `standardize` and `transform` to all the features. + Example: `features_transform = {'all': {'transform': lambda t:np.log(t+1), 'standardize': True}}`. + If both `all` and feature name/s are present, the latter have the priority over what indicated in `all`. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to None. + clustering_method (str | None, optional): "mcl" for Markov cluster algorithm (see https://micans.org/mcl/), + or "louvain" for Louvain method (see https://en.wikipedia.org/wiki/Louvain_method). + In both options, for each graph, the chosen method first finds communities (clusters) of nodes and generates + a torch tensor whose elements represent the cluster to which the node belongs to. Each tensor is then saved in + the .HDF5 file as a :class:`Dataset` called "depth_0". Then, all cluster members beloging to the same community are + pooled into a single node, and the resulting tensor is used to find communities among the pooled clusters. + The latter tensor is saved into the .HDF5 file as a :class:`Dataset` called "depth_1". Both "depth_0" and "depth_1" + :class:`Datasets` belong to the "cluster" Group. They are saved in the .HDF5 file to make them available to networks + that make use of clustering methods. Defaults to None. + target (str | None, optional): Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq. + It can also be a custom-defined target given to the Query class as input (see: `deeprank2.query`); + in this case, the task parameter needs to be explicitly specified as well. + Only numerical target variables are supported, not categorical. + If the latter is your case, please convert the categorical classes into + numerical class indices before defining the :class:`GraphDataset` instance. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to None. 
+ target_transform (bool, optional): Apply a log and then a sigmoid transformation to the target (for regression only). + This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to False. + target_filter (dict[str, str] | None, optional): Dictionary of type [target: cond] to filter the molecules. + Note that the you can filter on a different target than the one selected as the dataset target. + Defaults to None. + task (Literal["regress", "classif"] | None, optional): 'regress' for regression or 'classif' for classification. Required if target not in + ['irmsd', 'lrmsd', 'fnat', 'binary', 'capri_class', or 'dockq'], otherwise this setting is ignored. + Automatically set to 'classif' if the target is 'binary' or 'capri_classes'. + Automatically set to 'regress' if the target is 'irmsd', 'lrmsd', 'fnat', or 'dockq'. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to None. + classes (list[str] | list[int] | list[float] | None): Define the dataset target classes in classification mode. + Value will be ignored and inherited from `train_source` if `train_source` is assigned. + Defaults to None. + use_tqdm (bool, optional): Show progress bar. Defaults to True. + root (str, optional): Root directory where the dataset should be saved. Defaults to "./". + check_integrity (bool, optional): Whether to check the integrity of the hdf5 files. + Defaults to True. + """ + + def __init__( # noqa: C901 self, hdf5_path: str | list, subset: list[str] | None = None, train_source: str | GridDataset | None = None, - node_features: list[str] | str | Literal["all"] | None = "all", - edge_features: list[str] | str | Literal["all"] | None = "all", + node_features: list[str] | str | None = "all", + edge_features: list[str] | str | None = "all", features_transform: dict | None = None, clustering_method: str | None = None, target: str | None = None, @@ -715,74 +810,6 @@ def __init__( root: str = "./", check_integrity: bool = True, ): - """Class to load the .HDF5 files data into graphs. - - Args: - hdf5_path (str | list): Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list. Defaults to None. - subset (list[str] | None, optional): list of keys from .HDF5 file to include. Defaults to None (meaning include all). - train_source (str | class:`GraphDataset` | None, optional): data to inherit information from the training dataset or the pre-trained model. - If None, the current dataset is considered as the training set. Otherwise, `train_source` needs to be a dataset of the same class or - the path of a DeepRank2 pre-trained model. If set, the parameters `node_features`, `edge_features`, `features_transform`, - `target`, `target_transform`, `task`, and `classes` will be inherited from `train_source`. If standardization was performed in the - training dataset/step, also the attributes `means` and `devs` will be inherited from `train_source`, and they will be used to scale - the validation/testing set. - Defaults to None. - node_features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed node features ("all") or - some defined node features (provide a list, example: ["res_type", "polarity", "bsa"]). - The complete list can be found in `deeprank2.domain.nodestorage`. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. 
- Defaults to "all". - edge_features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed edge features ("all") or - some defined edge features (provide a list, example: ["dist", "coulomb"]). - The complete list can be found in `deeprank2.domain.edgestorage`. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to "all". - features_transform (dict | None, optional): Dictionary to indicate the transformations to apply to each feature in the dictionary, being the - transformations lambda functions and/or standardization. - Example: `features_transform = {'bsa': {'transform': lambda t:np.log(t+1),' standardize': True}}` for the feature `bsa`. - An `all` key can be set in the dictionary for indicating to apply the same `standardize` and `transform` to all the features. - Example: `features_transform = {'all': {'transform': lambda t:np.log(t+1), 'standardize': True}}`. - If both `all` and feature name/s are present, the latter have the priority over what indicated in `all`. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to None. - clustering_method (str | None, optional): "mcl" for Markov cluster algorithm (see https://micans.org/mcl/), - or "louvain" for Louvain method (see https://en.wikipedia.org/wiki/Louvain_method). - In both options, for each graph, the chosen method first finds communities (clusters) of nodes and generates - a torch tensor whose elements represent the cluster to which the node belongs to. Each tensor is then saved in - the .HDF5 file as a :class:`Dataset` called "depth_0". Then, all cluster members beloging to the same community are - pooled into a single node, and the resulting tensor is used to find communities among the pooled clusters. - The latter tensor is saved into the .HDF5 file as a :class:`Dataset` called "depth_1". Both "depth_0" and "depth_1" - :class:`Datasets` belong to the "cluster" Group. They are saved in the .HDF5 file to make them available to networks - that make use of clustering methods. Defaults to None. - target (str | None, optional): Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq. - It can also be a custom-defined target given to the Query class as input (see: `deeprank2.query`); - in this case, the task parameter needs to be explicitly specified as well. - Only numerical target variables are supported, not categorical. - If the latter is your case, please convert the categorical classes into - numerical class indices before defining the :class:`GraphDataset` instance. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to None. - target_transform (bool, optional): Apply a log and then a sigmoid transformation to the target (for regression only). - This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to False. - target_filter (dict[str, str] | None, optional): Dictionary of type [target: cond] to filter the molecules. - Note that the you can filter on a different target than the one selected as the dataset target. - Defaults to None. - task (Literal["regress", "classif"] | None, optional): 'regress' for regression or 'classif' for classification. Required if target not in - ['irmsd', 'lrmsd', 'fnat', 'binary', 'capri_class', or 'dockq'], otherwise this setting is ignored. 
- Automatically set to 'classif' if the target is 'binary' or 'capri_classes'. - Automatically set to 'regress' if the target is 'irmsd', 'lrmsd', 'fnat', or 'dockq'. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to None. - classes (list[str] | list[int] | list[float] | None): Define the dataset target classes in classification mode. - Value will be ignored and inherited from `train_source` if `train_source` is assigned. - Defaults to None. - use_tqdm (bool, optional): Show progress bar. Defaults to True. - root (str, optional): Root directory where the dataset should be saved. Defaults to "./". - check_integrity (bool, optional): Whether to check the integrity of the hdf5 files. - Defaults to True. - """ super().__init__( hdf5_path, subset, @@ -826,14 +853,17 @@ def __init__( try: fname, mol = self.index_entries[0] except IndexError as e: - raise IndexError("No entries found in the dataset. Please check the dataset parameters.") from e + msg = "No entries found in the dataset. Please check the dataset parameters." + raise IndexError(msg) from e with h5py.File(fname, "r") as f5: grp = f5[mol] possible_targets = grp[targets.VALUES].keys() if self.target is None: - raise ValueError(f"Please set the target during training dataset definition; targets present in the file/s are {possible_targets}.") + msg = f"Please set the target during training dataset definition; targets present in the file/s are {possible_targets}." + raise ValueError(msg) if self.target not in possible_targets: - raise ValueError(f"Target {self.target} not present in the file/s; targets present in the file/s are {possible_targets}.") + msg = f"Target {self.target} not present in the file/s; targets present in the file/s are {possible_targets}." + raise ValueError(msg) self.features_dict = {} self.features_dict[Nfeat.NODE] = self.node_features @@ -869,7 +899,7 @@ def get(self, idx: int) -> Data: fname, mol = self.index_entries[idx] return self.load_one_graph(fname, mol) - def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915 (too-many-statements) + def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915, C901 """Loads one graph. Args: @@ -908,10 +938,16 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915 with warnings.catch_warnings(record=True) as w: vals = transform(vals) if len(w) > 0: - raise ValueError( + msg = ( f"Invalid value occurs in {entry_name}, file {fname},when applying {transform} for feature {feat}.\n\t" f"Please change the transformation function for {feat}." ) + raise ValueError( + msg, + ) + raise ValueError( + msg, + ) if vals.ndim == 1: # features with only one channel vals = vals.reshape(-1, 1) @@ -931,7 +967,7 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915 # we have to have all the edges i.e : (i,j) and (j,i) if Efeat.INDEX in grp[Efeat.EDGE]: ind = grp[f"{Efeat.EDGE}/{Efeat.INDEX}"][()] - if ind.ndim == 2: + if ind.ndim == 2: # noqa: PLR2004 ind = np.vstack((ind, np.flip(ind, 1))).T edge_index = torch.tensor(ind, dtype=torch.long).contiguous() else: @@ -964,10 +1000,16 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915 with warnings.catch_warnings(record=True) as w: vals = transform(vals) if len(w) > 0: - raise ValueError( + msg = ( f"Invalid value occurs in {entry_name}, file {fname}, when applying {transform} for feature {feat}.\n\t" f"Please change the transformation function for {feat}." 
) + raise ValueError( + msg, + ) + raise ValueError( + msg, + ) if vals.ndim == 1: vals = vals.reshape(-1, 1) @@ -993,18 +1035,23 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915 if self.task == targets.REGRESS and self.target_transform is True: y = torch.sigmoid(torch.log(y)) elif self.task is not targets.REGRESS and self.target_transform is True: - raise ValueError( - f'Sigmoid transformation not possible for {self.task} tasks. Please change `task` to "regress" or set `target_transform` to `False`.' - ) + msg = f'Sigmoid transformation not possible for {self.task} tasks. Please change `task` to "regress" or set `target_transform` to `False`.' + raise ValueError(msg) else: y = None possible_targets = grp[targets.VALUES].keys() if self.train_source is None: - raise ValueError( + msg = ( f"Target {self.target} missing in entry {entry_name} in file {fname}, possible targets are {possible_targets}.\n\t" "Use the query class to add more target values to input data." ) + raise ValueError( + msg, + ) + raise ValueError( + msg, + ) # positions pos = torch.tensor(grp[f"{Nfeat.NODE}/{Nfeat.POSITION}/"][()], dtype=torch.float).contiguous() @@ -1038,7 +1085,7 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915 return data - def _check_features(self): + def _check_features(self) -> None: # noqa: C901 """Checks if the required features exist.""" f = h5py.File(self.hdf5_paths[0], "r") mol_key = next(iter(f.keys())) @@ -1090,7 +1137,7 @@ def _check_features(self): miss_node_error, miss_edge_error = "", "" _log.info( "\nCheck feature_modules passed to the preprocess function.\ - Probably, the feature wasn't generated during the preprocessing step." + Probably, the feature wasn't generated during the preprocessing step.", ) if missing_node_features: _log.info(f"\nAvailable node features: {self.available_node_features}\n") @@ -1100,15 +1147,26 @@ def _check_features(self): _log.info(f"\nAvailable edge features: {self.available_edge_features}\n") miss_edge_error = f"\nMissing edge features: {missing_edge_features} \ \nAvailable edge features: {self.available_edge_features}" - raise ValueError( + msg = ( f"Not all features could be found in the file {self.hdf5_paths[0]}.\n\t" "Check feature_modules passed to the preprocess function.\n\t" "Probably, the feature wasn't generated during the preprocessing step.\n\t" f"{miss_node_error}{miss_edge_error}" ) + raise ValueError( + msg, + ) + raise ValueError( + msg, + ) -def save_hdf5_keys(f_src_path: str, src_ids: list[str], f_dest_path: str, hardcopy=False): +def save_hdf5_keys( + f_src_path: str, + src_ids: list[str], + f_dest_path: str, + hardcopy: bool = False, +) -> None: """Save references to keys in src_ids in a new .HDF5 file. Args: @@ -1121,7 +1179,8 @@ def save_hdf5_keys(f_src_path: str, src_ids: list[str], f_dest_path: str, hardco Defaults to False. """ if not all(isinstance(d, str) for d in src_ids): - raise TypeError("data_ids should be a list containing strings.") + msg = "data_ids should be a list containing strings." 
+ raise TypeError(msg) with h5py.File(f_dest_path, "w") as f_dest, h5py.File(f_src_path, "r") as f_src: for key in src_ids: diff --git a/deeprank2/domain/aminoacidlist.py b/deeprank2/domain/aminoacidlist.py index b49a96c77..36a799184 100644 --- a/deeprank2/domain/aminoacidlist.py +++ b/deeprank2/domain/aminoacidlist.py @@ -1,3 +1,5 @@ +from typing import Literal + from deeprank2.molstruct.aminoacid import AminoAcid, Polarity # All info below sourced from above websites in December 2022 and summarized in deeprank2/domain/aminoacid_summary.xlsx @@ -377,21 +379,42 @@ amino_acids_by_name = {amino_acid.name: amino_acid for amino_acid in amino_acids} -def convert_aa_nomenclature(aa: str, output_type: int | None = None): +def convert_aa_nomenclature(aa: str, output_format: Literal[0, 1, 3] = 0) -> str: + """Converts amino acid nomenclatures. + + Conversions are possible between the standard 1-letter codes, 3-letter + codes, and full names of amino acids. + + Args: + aa (str): The amino acid to be converted in any of its formats. The + length of the string is used to determine which format is used. + output_format (Literal[0, 1, 3], optional): Nomenclature style to return: + 0 (default) returns the full name, + 1 returns the 1-letter code, + 3 returns the 3-letter code. + + Raises: + ValueError: If aa is not recognized or an invalid output format was given + + Returns: + str: amino acid in the selected nomenclature system. + """ try: if len(aa) == 1: aa: AminoAcid = next(entry for entry in amino_acids if entry.one_letter_code.lower() == aa.lower()) - elif len(aa) == 3: + elif len(aa) == 3: # noqa:PLR2004 aa: AminoAcid = next(entry for entry in amino_acids if entry.three_letter_code.lower() == aa.lower()) else: aa: AminoAcid = next(entry for entry in amino_acids if entry.name.lower() == aa.lower()) except IndexError as e: - raise ValueError(f"{aa} is not a valid amino acid.") from e + msg = f"{aa} is not a valid amino acid." + raise ValueError(msg) from e - if not output_type: + if not output_format: return aa.name - if output_type == 3: + if output_format == 3: # noqa:PLR2004 return aa.three_letter_code - if output_type == 1: + if output_format == 1: return aa.one_letter_code - raise ValueError(f"output_type {output_type} not recognized. Must be set to None (amino acid name), 1 (one letter code), or 3 (three letter code).") + msg = f"output_format {output_format} not recognized. Must be set to 0 (amino acid name), 1 (one letter code), or 3 (three letter code)." 
+ raise ValueError(msg) diff --git a/deeprank2/features/components.py b/deeprank2/features/components.py index d4844badb..3ebc21d35 100644 --- a/deeprank2/features/components.py +++ b/deeprank2/features/components.py @@ -11,11 +11,11 @@ _log = logging.getLogger(__name__) -def add_features( - pdb_path: str, # noqa: ARG001 (unused argument) +def add_features( # noqa:D103 + pdb_path: str, # noqa: ARG001 graph: Graph, single_amino_acid_variant: SingleResidueVariant | None = None, -): +) -> None: for node in graph.nodes: if isinstance(node.id, Residue): residue = node.id @@ -27,7 +27,8 @@ def add_features( node.features[Nfeat.PDBOCCUPANCY] = atom.occupancy node.features[Nfeat.ATOMCHARGE] = atomic_forcefield.get_charge(atom) else: - raise TypeError(f"Unexpected node type: {type(node.id)}") + msg = f"Unexpected node type: {type(node.id)}" + raise TypeError(msg) node.features[Nfeat.RESTYPE] = residue.amino_acid.onehot node.features[Nfeat.RESCHARGE] = residue.amino_acid.charge diff --git a/deeprank2/features/conservation.py b/deeprank2/features/conservation.py index 09e002eb9..66ffa4cf9 100644 --- a/deeprank2/features/conservation.py +++ b/deeprank2/features/conservation.py @@ -7,11 +7,11 @@ from deeprank2.utils.graph import Graph -def add_features( - pdb_path: str, # noqa: ARG001 (unused argument) +def add_features( # noqa:D103 + pdb_path: str, # noqa: ARG001 graph: Graph, single_amino_acid_variant: SingleResidueVariant | None = None, -): +) -> None: profile_amino_acid_order = sorted(amino_acids, key=lambda aa: aa.three_letter_code) for node in graph.nodes: @@ -21,7 +21,8 @@ def add_features( atom = node.id residue = atom.residue else: - raise TypeError(f"Unexpected node type: {type(node.id)}") + msg = f"Unexpected node type: {type(node.id)}" + raise TypeError(msg) pssm_row = residue.get_pssm() profile = np.array([pssm_row.get_conservation(amino_acid) for amino_acid in profile_amino_acid_order]) diff --git a/deeprank2/features/contact.py b/deeprank2/features/contact.py index 927c66208..47085da38 100644 --- a/deeprank2/features/contact.py +++ b/deeprank2/features/contact.py @@ -72,11 +72,11 @@ def _get_nonbonded_energy( return E_elec, E_vdw -def add_features( - pdb_path: str, # noqa: ARG001 (unused argument) +def add_features( # noqa:D103 + pdb_path: str, # noqa: ARG001 graph: Graph, - single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 (unused argument) -): + single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 +) -> None: # assign each atoms (from all edges) a unique index all_atoms = set() if isinstance(graph.edges[0].id, AtomicContact): @@ -90,7 +90,8 @@ def add_features( for atom in contact.residue1.atoms + contact.residue2.atoms: all_atoms.add(atom) else: - raise TypeError(f"Unexpected edge type: {type(graph.edges[0].id)}") + msg = f"Unexpected edge type: {type(graph.edges[0].id)}" + raise TypeError(msg) all_atoms = list(all_atoms) atom_dict = {atom: i for i, atom in enumerate(all_atoms)} diff --git a/deeprank2/features/exposure.py b/deeprank2/features/exposure.py index a48a6b428..88d512ad7 100644 --- a/deeprank2/features/exposure.py +++ b/deeprank2/features/exposure.py @@ -2,6 +2,7 @@ import signal import sys import warnings +from typing import NoReturn import numpy as np from Bio.PDB.Atom import PDBConstructionWarning @@ -17,26 +18,27 @@ _log = logging.getLogger(__name__) -def handle_sigint(sig, frame): # noqa: ARG001 (unused argument) - print("SIGINT received, terminating.") +def handle_sigint(sig, frame) -> None: # noqa: ARG001, ANN001, 
D103 + _log.info("SIGINT received, terminating.") sys.exit() -def handle_timeout(sig, frame): # noqa: ARG001 (unused argument) - raise TimeoutError("Timed out!") +def handle_timeout(sig, frame) -> NoReturn: # noqa: ARG001, ANN001, D103 + msg = "Timed out!" + raise TimeoutError(msg) -def space_if_none(value): +def space_if_none(value: str) -> str: # noqa:D103 if value is None: return " " return value -def add_features( +def add_features( # noqa:D103 pdb_path: str, graph: Graph, - single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 (unused argument) -): + single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 +) -> None: signal.signal(signal.SIGINT, handle_sigint) signal.signal(signal.SIGALRM, handle_timeout) @@ -50,7 +52,8 @@ def add_features( surface = get_surface(bio_model) signal.alarm(0) except TimeoutError as e: - raise TimeoutError("Bio.PDB.ResidueDepth.get_surface timed out.") from e + msg = "Bio.PDB.ResidueDepth.get_surface timed out." + raise TimeoutError(msg) from e # These can only be calculated per residue, not per atom. # So for atomic graphs, every atom gets its residue's value. @@ -62,7 +65,8 @@ def add_features( atom = node.id residue = atom.residue else: - raise TypeError(f"Unexpected node type: {type(node.id)}") + msg = f"Unexpected node type: {type(node.id)}" + raise TypeError(msg) bio_residue = bio_model[residue.chain.id][residue.number] node.features[Nfeat.RESDEPTH] = residue_depth(bio_residue, surface) diff --git a/deeprank2/features/irc.py b/deeprank2/features/irc.py index 6f3ddc8dd..f1d0f4c07 100644 --- a/deeprank2/features/irc.py +++ b/deeprank2/features/irc.py @@ -11,6 +11,7 @@ from deeprank2.utils.graph import Graph _log = logging.getLogger(__name__) +SAFE_MIN_CONTACTS = 5 def _id_from_residue(residue: tuple[str, int, str]) -> str: @@ -100,11 +101,11 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, return residue_contacts -def add_features( +def add_features( # noqa: C901, D103 pdb_path: str, graph: Graph, single_amino_acid_variant: SingleResidueVariant | None = None, -): +) -> None: if not single_amino_acid_variant: # VariantQueries do not use this feature polarity_pairs = list(combinations(Polarity, 2)) polarity_pair_string = [f"irc_{x[0].name.lower()}_{x[1].name.lower()}" for x in polarity_pairs] @@ -119,7 +120,8 @@ def add_features( atom = node.id residue = atom.residue else: - raise TypeError(f"Unexpected node type: {type(node.id)}") + msg = f"Unexpected node type: {type(node.id)}" + raise TypeError(msg) contact_id = residue.chain.id + residue.number_string # reformat id to be in line with residue_contacts keys @@ -139,5 +141,5 @@ def add_features( except KeyError: # node has no contact residues and all counts remain 0 pass - if total_contacts < 5: + if total_contacts < SAFE_MIN_CONTACTS: _log.warning(f"Few ({total_contacts}) contacts detected for {pdb_path}.") diff --git a/deeprank2/features/secondary_structure.py b/deeprank2/features/secondary_structure.py index d3a2182bf..a42267ef1 100644 --- a/deeprank2/features/secondary_structure.py +++ b/deeprank2/features/secondary_structure.py @@ -4,6 +4,7 @@ import numpy as np from Bio.PDB import PDBParser from Bio.PDB.DSSP import DSSP +from numpy.typing import NDArray from deeprank2.domain import nodestorage as Nfeat from deeprank2.molstruct.atom import Atom @@ -23,20 +24,20 @@ class SecondarySctructure(Enum): COIL = 2 # ' -STP' @property - def onehot(self): + def onehot(self) -> NDArray: t = np.zeros(3) t[self.value] = 1.0 return t 
-def _get_records(lines: list[str]): +def _get_records(lines: list[str]) -> list[str]: seen = set() seen_add = seen.add return [x.split()[0] for x in lines if not (x in seen or seen_add(x))] -def _check_pdb(pdb_path: str): +def _check_pdb(pdb_path: str) -> None: fix_pdb = False with open(pdb_path, encoding="utf-8") as f: lines = f.readlines() @@ -71,7 +72,7 @@ def _check_pdb(pdb_path: str): f.writelines(lines) -def _classify_secstructure(subtype: str): +def _classify_secstructure(subtype: str) -> SecondarySctructure: if subtype in "GHI": return SecondarySctructure.HELIX if subtype in "BE": @@ -97,13 +98,16 @@ def _get_secstructure(pdb_path: str) -> dict: try: dssp = DSSP(model, pdb_path, dssp="mkdssp") - except Exception as e: # noqa: BLE001 (blind-except), namely: # improperly formatted pdb files raise: `Exception: DSSP failed to produce an output` + except Exception as e: # noqa: BLE001, namely: # improperly formatted pdb files raise: `Exception: DSSP failed to produce an output` pdb_format_link = "https://www.wwpdb.org/documentation/file-format-content/format33/sect1.html#Order" - raise DSSPError( + msg = ( f"DSSP has raised the following exception: {e}.\n\t" f"This is likely due to an improrperly formatted pdb file: {pdb_path}.\n\t" f"See {pdb_format_link} for guidance on how to format your pdb files.\n\t" "Alternatively, turn off secondary_structure feature module during QueryCollection.process()." + ) + raise DSSPError( + msg, ) from e chain_ids = [dssp_key[0] for dssp_key in dssp.property_keys] @@ -120,11 +124,11 @@ def _get_secstructure(pdb_path: str) -> dict: return sec_structure_dict -def add_features( +def add_features( # noqa:D103 pdb_path: str, graph: Graph, - single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 (unused argument) -): + single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 +) -> None: sec_structure_features = _get_secstructure(pdb_path) for node in graph.nodes: @@ -134,7 +138,8 @@ def add_features( atom = node.id residue = atom.residue else: - raise TypeError(f"Unexpected node type: {type(node.id)}") + msg = f"Unexpected node type: {type(node.id)}" + raise TypeError(msg) chain_id = residue.chain.id res_num = residue.number @@ -142,6 +147,7 @@ def add_features( try: node.features[Nfeat.SECSTRUCT] = _classify_secstructure(sec_structure_features[chain_id][res_num]).onehot except AttributeError as e: + msg = f"Unknown secondary structure type ({sec_structure_features[chain_id][res_num]}) detected on chain {chain_id} residues {res_num}." raise ValueError( - f"Unknown secondary structure type ({sec_structure_features[chain_id][res_num]}) detected on chain {chain_id} residues {res_num}." 
+ msg, ) from e diff --git a/deeprank2/features/surfacearea.py b/deeprank2/features/surfacearea.py index 67cc64920..afffca367 100644 --- a/deeprank2/features/surfacearea.py +++ b/deeprank2/features/surfacearea.py @@ -12,7 +12,7 @@ logging.getLogger(__name__) -def add_sasa(pdb_path: str, graph: Graph): +def add_sasa(pdb_path: str, graph: Graph) -> None: # noqa:D103 structure = freesasa.Structure(pdb_path) result = freesasa.calc(structure) @@ -29,14 +29,16 @@ def add_sasa(pdb_path: str, graph: Graph): area = freesasa.selectArea(selection, structure, result)["atom"] else: - raise TypeError(f"Unexpected node type: {type(node.id)}") + msg = f"Unexpected node type: {type(node.id)}" + raise TypeError(msg) if np.isnan(area): - raise ValueError(f"freesasa returned {area} for {residue}") + msg = f"freesasa returned {area} for {residue}" + raise ValueError(msg) node.features[Nfeat.SASA] = area -def add_bsa(graph: Graph): +def add_bsa(graph: Graph) -> None: # noqa:D103 sasa_complete_structure = freesasa.Structure() sasa_chain_structures = {} @@ -96,7 +98,8 @@ def add_bsa(graph: Graph): area_key = "atom" selection = f"atom, (name {atom.name}) and (resi {atom.residue.number_string}) and (chain {atom.residue.chain.id})" else: - raise TypeError(f"Unexpected node type: {type(node.id)}") + msg = f"Unexpected node type: {type(node.id)}" + raise TypeError(msg) sasa_complete_result = freesasa.calc(sasa_complete_structure) sasa_chain_results = {chain_id: freesasa.calc(structure) for chain_id, structure in sasa_chain_structures.items()} @@ -123,8 +126,8 @@ def add_bsa(graph: Graph): def add_features( pdb_path: str, graph: Graph, - single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 (unused argument) -): + single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 +) -> None: """Calculates the Buried Surface Area (BSA) and the Solvent Accessible Surface Area (SASA).""" # BSA add_bsa(graph) diff --git a/deeprank2/molstruct/aminoacid.py b/deeprank2/molstruct/aminoacid.py index 8400f9c70..2a9f33c4c 100644 --- a/deeprank2/molstruct/aminoacid.py +++ b/deeprank2/molstruct/aminoacid.py @@ -2,6 +2,7 @@ import numpy as np from numpy.typing import NDArray +from typing_extensions import Self class Polarity(Enum): @@ -13,13 +14,29 @@ class Polarity(Enum): POSITIVE = 3 @property - def onehot(self): + def onehot(self) -> NDArray: t = np.zeros(4) t[self.value] = 1.0 return t class AminoAcid: + """An amino acid represents the type of `Residue` in a `PDBStructure`. + + Args: + name (str): Full name of the amino acid. + three_letter_code (str): Three-letter code of the amino acid (as in PDB). + one_letter_code (str): One-letter of the amino acid (as in fasta). + charge (int): Charge of the amino acid. + polarity (:class:`Polarity`): The polarity of the amino acid. + size (int): The number of non-hydrogen atoms in the side chain. + mass (float): Average residue mass (i.e. mass of amino acid - H20) in Daltons. + pI (float): Isolectric point; pH at which the molecule has no net electric charge. + hydrogen_bond_donors (int): Number of hydrogen bond donors. + hydrogen_bond_acceptors (int): Number of hydrogen bond acceptors. + index (int): The rank of the amino acid, used for computing one-hot encoding. + """ + def __init__( self, name: str, @@ -34,21 +51,6 @@ def __init__( hydrogen_bond_acceptors: int, index: int, ): - """An amino acid represents the type of `Residue` in a `PDBStructure`. - - Args: - name (str): Full name of the amino acid. 
- three_letter_code (str): Three-letter code of the amino acid (as in PDB). - one_letter_code (str): One-letter of the amino acid (as in fasta). - charge (int): Charge of the amino acid. - polarity (:class:`Polarity`): The polarity of the amino acid. - size (int): The number of non-hydrogen atoms in the side chain. - mass (float): Average residue mass (i.e. mass of amino acid - H20) in Daltons. - pI (float): Isolectric point; pH at which the molecule has no net electric charge. - hydrogen_bond_donors (int): Number of hydrogen bond donors. - hydrogen_bond_acceptors (int): Number of hydrogen bond acceptors. - index (int): The rank of the amino acid, used for computing one-hot encoding. - """ # amino acid nomenclature self._name = name self._three_letter_code = three_letter_code @@ -109,7 +111,8 @@ def hydrogen_bond_acceptors(self) -> int: @property def onehot(self) -> NDArray: if self._index is None: - raise ValueError(f"Amino acid {self._name} index is not set, thus no onehot can be computed.") + msg = f"Amino acid {self._name} index is not set, thus no onehot can be computed." + raise ValueError(msg) # 20 canonical amino acids # selenocysteine and pyrrolysine are indexed as cysteine and lysine, respectively a = np.zeros(20) @@ -123,7 +126,7 @@ def index(self) -> int: def __hash__(self) -> hash: return hash(self.name) - def __eq__(self, other) -> bool: + def __eq__(self, other: Self) -> bool: if isinstance(other, AminoAcid): return other.name == self.name return NotImplemented diff --git a/deeprank2/molstruct/atom.py b/deeprank2/molstruct/atom.py index 7088ee4b4..d26f0cc13 100644 --- a/deeprank2/molstruct/atom.py +++ b/deeprank2/molstruct/atom.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import numpy as np +from typing_extensions import Self if TYPE_CHECKING: from numpy.typing import NDArray @@ -15,7 +16,7 @@ class AtomicElement(Enum): """One-hot encoding of the atomic element (or atom type).""" C = 1 - O = 2 # noqa: E741 (ambiguous-variable-name) + O = 2 # noqa: E741 N = 3 S = 4 P = 5 @@ -29,6 +30,19 @@ def onehot(self) -> np.array: class Atom: + """One atom in a PDBStructure. + + Args: + residue (:class:`Residue`): The residue that this atom belongs to. + name (str): Pdb atom name. + element (:class:`AtomicElement`): The chemical element. + position (np.array): Pdb position xyz of this atom. + occupancy (float): Pdb occupancy value. + This represents the proportion of structures where the atom is detected at a given position. + Sometimes a single atom can be detected at multiple positions. In that case separate structures exist where sum(occupancy) == 1. + Note that only the highest occupancy atom is used by deeprank2 (see tools.pdb._add_atom_to_residue). + """ + def __init__( self, residue: Residue, @@ -37,25 +51,13 @@ def __init__( position: NDArray, occupancy: float, ): - """One atom in a PDBStructure. - - Args: - residue (:class:`Residue`): The residue that this atom belongs to. - name (str): Pdb atom name. - element (:class:`AtomicElement`): The chemical element. - position (np.array): Pdb position xyz of this atom. - occupancy (float): Pdb occupancy value. - This represents the proportion of structures where the atom is detected at a given position. - Sometimes a single atom can be detected at multiple positions. In that case separate structures exist where sum(occupancy) == 1. - Note that only the highest occupancy atom is used by deeprank2 (see tools.pdb._add_atom_to_residue). 
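The `AminoAcid` and `Atom` hunks above share one refactoring pattern: the docstring that previously lived inside `__init__` is promoted to a class-level docstring, and the constructor body is left undocumented. A minimal sketch of that pattern follows; the `ExampleBefore`/`ExampleAfter` classes and their field are hypothetical and only illustrate the before/after shape, they are not part of this diff.

```python
# Before: the docstring sits inside __init__, where help() and most doc
# generators do not surface it as the class description.
class ExampleBefore:
    def __init__(self, name: str):
        """A hypothetical example object.

        Args:
            name (str): Human-readable identifier.
        """
        self.name = name


# After: the same text becomes the class docstring, mirroring the change
# applied to AminoAcid and Atom in the hunks above.
class ExampleAfter:
    """A hypothetical example object.

    Args:
        name (str): Human-readable identifier.
    """

    def __init__(self, name: str):
        self.name = name
```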
- """ self._residue = residue self._name = name self._element = element self._position = position self._occupancy = occupancy - def __eq__(self, other) -> bool: + def __eq__(self, other: Self) -> bool: if isinstance(other, Atom): return self._residue == other._residue and self._name == other._name return NotImplemented @@ -66,7 +68,7 @@ def __hash__(self) -> hash: def __repr__(self) -> str: return f"{self._residue} {self._name}" - def change_altloc(self, alternative_atom: Atom): + def change_altloc(self, alternative_atom: Atom) -> None: """Replace the atom's location by another atom's location.""" self._position = alternative_atom.position self._occupancy = alternative_atom.occupancy diff --git a/deeprank2/molstruct/pair.py b/deeprank2/molstruct/pair.py index 41b732693..463db4200 100644 --- a/deeprank2/molstruct/pair.py +++ b/deeprank2/molstruct/pair.py @@ -1,18 +1,21 @@ from abc import ABC from typing import Any +from typing_extensions import Self + from deeprank2.molstruct.atom import Atom from deeprank2.molstruct.residue import Residue class Pair: - def __init__(self, item1: Any, item2: Any): - """A hashable, comparable object for any set of two inputs where order doesn't matter. + """A hashable, comparable object for any set of two inputs where order doesn't matter. + + Args: + item1 (Any object): The pair's first object, must be convertable to string. + item2 (Any object): The pair's second object, must be convertable to string. + """ - Args: - item1 (Any object): The pair's first object, must be convertable to string. - item2 (Any object): The pair's second object, must be convertable to string. - """ + def __init__(self, item1: Any, item2: Any): # noqa: ANN401 self.item1 = item1 self.item2 = item2 @@ -24,7 +27,7 @@ def __hash__(self) -> hash: return hash(s1 + s2) return hash(s2 + s1) - def __eq__(self, other) -> bool: + def __eq__(self, other: Self) -> bool: """Compare the pairs as sets, so the order doesn't matter.""" if isinstance(other, Pair): return self.item1 == other.item1 and self.item2 == other.item2 or self.item1 == other.item2 and self.item2 == other.item1 diff --git a/deeprank2/molstruct/residue.py b/deeprank2/molstruct/residue.py index 2b1412676..17553b3fb 100644 --- a/deeprank2/molstruct/residue.py +++ b/deeprank2/molstruct/residue.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import numpy as np +from typing_extensions import Self if TYPE_CHECKING: from numpy.typing import NDArray @@ -41,7 +42,7 @@ def __init__( self._insertion_code = insertion_code self._atoms = [] - def __eq__(self, other) -> bool: + def __eq__(self, other: Self) -> bool: if isinstance(other, Residue): return self._chain == other._chain and self._number == other._number and self._insertion_code == other._insertion_code return NotImplemented @@ -53,7 +54,8 @@ def get_pssm(self) -> PssmRow: """Load pssm info linked to the residue.""" pssm = self._chain.pssm if pssm is None: - raise FileNotFoundError(f"No pssm file found for Chain {self._chain}.") + msg = f"No pssm file found for Chain {self._chain}." 
+ raise FileNotFoundError(msg) return pssm[self] @property @@ -83,7 +85,7 @@ def number_string(self) -> str: def insertion_code(self) -> str: return self._insertion_code - def add_atom(self, atom: Atom): + def add_atom(self, atom: Atom) -> None: self._atoms.append(atom) def __repr__(self) -> str: @@ -110,19 +112,21 @@ def get_center(self) -> NDArray: return alphas[0].position if len(self.atoms) == 0: - raise ValueError(f"Cannot get the center position from {self}, because it has no atoms") + msg = f"Cannot get the center position from {self}, because it has no atoms" + raise ValueError(msg) return np.mean([atom.position for atom in self.atoms], axis=0) class SingleResidueVariant: - def __init__(self, residue: Residue, variant_amino_acid: AminoAcid): - """A single residue mutation of a PDBStrcture. + """A single residue mutation of a PDBStrcture. - Args: - residue (Residue): the `Residue` object from the PDBStructure that is mutated. - variant_amino_acid (AminoAcid): the amino acid that the `Residue` is mutated into. - """ + Args: + residue (Residue): the `Residue` object from the PDBStructure that is mutated. + variant_amino_acid (AminoAcid): the amino acid that the `Residue` is mutated into. + """ + + def __init__(self, residue: Residue, variant_amino_acid: AminoAcid): self._residue = residue self._variant_amino_acid = variant_amino_acid diff --git a/deeprank2/molstruct/structure.py b/deeprank2/molstruct/structure.py index a5dcbc401..cbbcc0a40 100644 --- a/deeprank2/molstruct/structure.py +++ b/deeprank2/molstruct/structure.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +from typing_extensions import Self + if TYPE_CHECKING: from deeprank2.molstruct.atom import Atom from deeprank2.molstruct.residue import Residue @@ -24,7 +26,7 @@ def __init__(self, id_: str | None = None): self._id = id_ self._chains = {} - def __eq__(self, other) -> bool: + def __eq__(self, other: Self) -> bool: if isinstance(other, PDBStructure): return self._id == other._id return NotImplemented @@ -41,9 +43,10 @@ def has_chain(self, chain_id: str) -> bool: def get_chain(self, chain_id: str) -> Chain: return self._chains[chain_id] - def add_chain(self, chain: Chain): + def add_chain(self, chain: Chain) -> None: if chain.id in self._chains: - raise ValueError(f"Duplicate chain: {chain.id}") + msg = f"Duplicate chain: {chain.id}" + raise ValueError(msg) self._chains[chain.id] = chain @property @@ -90,10 +93,10 @@ def pssm(self) -> PssmRow: return self._pssm @pssm.setter - def pssm(self, pssm: PssmRow): + def pssm(self, pssm: PssmRow) -> None: self._pssm = pssm - def add_residue(self, residue: Residue): + def add_residue(self, residue: Residue) -> None: self._residues[(residue.number, residue.insertion_code)] = residue def has_residue(self, residue_number: int, insertion_code: str | None = None) -> bool: @@ -118,7 +121,7 @@ def get_atoms(self) -> list[Atom]: return atoms - def __eq__(self, other) -> bool: + def __eq__(self, other: Self) -> bool: if isinstance(other, Chain): return self._model == other._model and self._id == other._id return NotImplemented diff --git a/deeprank2/neuralnets/cnn/model3d.py b/deeprank2/neuralnets/cnn/model3d.py index 45805691b..21ff0face 100644 --- a/deeprank2/neuralnets/cnn/model3d.py +++ b/deeprank2/neuralnets/cnn/model3d.py @@ -3,6 +3,8 @@ import torch.nn.functional as F from torch.autograd import Variable +# ruff: noqa: ANN001, ANN201, ANN202 + ###################################################################### # # Model automatically generated by modelGenerator @@ -21,7 +23,7 @@ 
# ---------------------------------------------------------------------- -class CnnRegression(torch.nn.Module): +class CnnRegression(torch.nn.Module): # noqa: D101 def __init__(self, num_features: int, box_shape: tuple[int]): super().__init__() @@ -74,7 +76,7 @@ def forward(self, data): # ---------------------------------------------------------------------- -class CnnClassification(torch.nn.Module): +class CnnClassification(torch.nn.Module): # noqa: D101 def __init__(self, num_features, box_shape): super().__init__() diff --git a/deeprank2/neuralnets/gnn/alignmentnet.py b/deeprank2/neuralnets/gnn/alignmentnet.py index a53cfabbd..077cd5dcf 100644 --- a/deeprank2/neuralnets/gnn/alignmentnet.py +++ b/deeprank2/neuralnets/gnn/alignmentnet.py @@ -1,10 +1,12 @@ import torch from torch import nn +# ruff: noqa: ANN001, ANN201 + __author__ = "Daniel-Tobias Rademaker" -class GNNLayer(nn.Module): +class GNNLayer(nn.Module): # noqa: D101 def __init__( self, nmb_edge_projection, @@ -12,7 +14,7 @@ def __init__( nmb_output_features, message_vector_length, nmb_mlp_neurons, - act_fn=nn.SiLU(), # noqa: B008 (function-call-in-default-argument) + act_fn=nn.SiLU(), # noqa: B008 is_last_layer=True, ): super().__init__() @@ -102,7 +104,7 @@ def output(self, hidden_features, get_attention=True): return output -class SuperGNN(nn.Module): +class SuperGNN(nn.Module): # noqa: D101 def __init__( self, nmb_edge_attr, @@ -113,7 +115,7 @@ def __init__( nmb_gnn_layers, nmb_output_features, message_vector_length, - act_fn=nn.SiLU(), # noqa: B008 (function-call-in-default-argument) + act_fn=nn.SiLU(), # noqa: B008 ): super().__init__() @@ -149,7 +151,7 @@ def __init__( is_last_layer=(gnn_layer == (nmb_gnn_layers - 1)), ) for gnn_layer in range(nmb_gnn_layers) - ] + ], ) # always use this function before running the GNN layers @@ -165,12 +167,12 @@ def run_through_network(self, edges, edge_attr, node_attr, with_output_attention for layer in self.modlist: node_attr = layer.update_nodes(edges, edge_attr, node_attr) if with_output_attention: - representations, attention = self.modlist[-1].output(node_attr, True) # noqa: FBT003 (boolean-positional-value-in-call) + representations, attention = self.modlist[-1].output(node_attr, True) # (boolean-positional-value-in-call) return representations, attention - return self.modlist[-1].output(node_attr, True) # noqa: FBT003 (boolean-positional-value-in-call) + return self.modlist[-1].output(node_attr, True) # (boolean-positional-value-in-call) -class AlignmentGNN(SuperGNN): +class AlignmentGNN(SuperGNN): # noqa: D101 def __init__( self, nmb_edge_attr, @@ -181,7 +183,7 @@ def __init__( nmb_mlp_neurons, nmb_gnn_layers, nmb_edge_projection, - act_fn=nn.SiLU(), # noqa: B008 (function-call-in-default-argument) + act_fn=nn.SiLU(), # noqa: B008 ): super().__init__( nmb_edge_attr, diff --git a/deeprank2/neuralnets/gnn/foutnet.py b/deeprank2/neuralnets/gnn/foutnet.py index 2ae898860..4a68ef1cc 100644 --- a/deeprank2/neuralnets/gnn/foutnet.py +++ b/deeprank2/neuralnets/gnn/foutnet.py @@ -8,6 +8,8 @@ from deeprank2.utils.community_pooling import community_pooling, get_preloaded_cluster +# ruff: noqa: ANN001, ANN201 + class FoutLayer(torch.nn.Module): """FoutLayer. 
@@ -40,7 +42,7 @@ def __init__(self, in_channels: int, out_channels: int, bias: bool = True): self.reset_parameters() - def reset_parameters(self): + def reset_parameters(self) -> None: size = self.in_channels uniform(size, self.wc) uniform(size, self.wn) @@ -70,12 +72,12 @@ def __repr__(self): return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})" -class FoutNet(torch.nn.Module): +class FoutNet(torch.nn.Module): # noqa: D101 def __init__( self, input_shape, output_shape=1, - input_shape_edge=None, # noqa: ARG002 (unused argument) + input_shape_edge=None, # noqa: ARG002 ): super().__init__() diff --git a/deeprank2/neuralnets/gnn/ginet.py b/deeprank2/neuralnets/gnn/ginet.py index 95d897aff..92d2d252f 100644 --- a/deeprank2/neuralnets/gnn/ginet.py +++ b/deeprank2/neuralnets/gnn/ginet.py @@ -7,8 +7,10 @@ from deeprank2.utils.community_pooling import community_pooling, get_preloaded_cluster +# ruff: noqa: ANN001, ANN201 -class GINetConvLayer(torch.nn.Module): + +class GINetConvLayer(torch.nn.Module): # noqa: D101 def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False): super().__init__() @@ -20,7 +22,7 @@ def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False self.fc_attention = nn.Linear(2 * self.out_channels + number_edge_features, 1, bias=bias) self.reset_parameters() - def reset_parameters(self): + def reset_parameters(self) -> None: size = self.in_channels uniform(size, self.fc.weight) uniform(size, self.fc_attention.weight) @@ -52,7 +54,7 @@ def __repr__(self): return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})" -class GINet(torch.nn.Module): +class GINet(torch.nn.Module): # noqa: D101 # input_shape -> number of node input features # output_shape -> number of output value per graph # input_shape_edge -> number of edge input features diff --git a/deeprank2/neuralnets/gnn/ginet_nocluster.py b/deeprank2/neuralnets/gnn/ginet_nocluster.py index 849d617f3..78102e5da 100644 --- a/deeprank2/neuralnets/gnn/ginet_nocluster.py +++ b/deeprank2/neuralnets/gnn/ginet_nocluster.py @@ -4,8 +4,10 @@ from torch_geometric.nn.inits import uniform from torch_scatter import scatter_mean, scatter_sum +# ruff: noqa: ANN001, ANN201 -class GINetConvLayer(torch.nn.Module): + +class GINetConvLayer(torch.nn.Module): # noqa: D101 def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False): super().__init__() @@ -17,7 +19,7 @@ def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False self.fc_attention = nn.Linear(2 * self.out_channels + number_edge_features, 1, bias=bias) self.reset_parameters() - def reset_parameters(self): + def reset_parameters(self) -> None: size = self.in_channels uniform(size, self.fc.weight) uniform(size, self.fc_attention.weight) @@ -49,7 +51,7 @@ def __repr__(self): return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})" -class GINet(torch.nn.Module): +class GINet(torch.nn.Module): # noqa: D101 # input_shape -> number of node input features # output_shape -> number of output value per graph # input_shape_edge -> number of edge input features diff --git a/deeprank2/neuralnets/gnn/naive_gnn.py b/deeprank2/neuralnets/gnn/naive_gnn.py index 873f3005c..50ab2f0ff 100644 --- a/deeprank2/neuralnets/gnn/naive_gnn.py +++ b/deeprank2/neuralnets/gnn/naive_gnn.py @@ -4,8 +4,10 @@ from torch.nn import Linear, Module, ReLU, Sequential from torch_scatter import scatter_mean, scatter_sum +# ruff: noqa: ANN001, ANN201 -class 
NaiveConvolutionalLayer(Module): + +class NaiveConvolutionalLayer(Module): # noqa: D101 def __init__(self, count_node_features, count_edge_features): super().__init__() message_size = 32 @@ -29,7 +31,7 @@ def forward(self, node_features, edge_node_indices, edge_features): return self._node_mlp(node_input) -class NaiveNetwork(Module): +class NaiveNetwork(Module): # noqa: D101 def __init__(self, input_shape: int, output_shape: int, input_shape_edge: int): """NaiveNetwork. diff --git a/deeprank2/neuralnets/gnn/sgat.py b/deeprank2/neuralnets/gnn/sgat.py index b594181cc..8b8b702e4 100644 --- a/deeprank2/neuralnets/gnn/sgat.py +++ b/deeprank2/neuralnets/gnn/sgat.py @@ -8,6 +8,8 @@ from deeprank2.utils.community_pooling import community_pooling, get_preloaded_cluster +# ruff: noqa: ANN001, ANN201 + class SGraphAttentionLayer(torch.nn.Module): """SGraphAttentionLayer. @@ -23,9 +25,15 @@ class SGraphAttentionLayer(torch.nn.Module): out_channels (int): Size of each output sample. bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. Defaults to True. - """ # noqa: D301 (escape-sequence-in-docstring) + """ # noqa: D301 - def __init__(self, in_channels: int, out_channels: int, bias: bool = True, undirected=True): + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + undirected: bool = True, + ): super().__init__() self.in_channels = in_channels @@ -41,7 +49,7 @@ def __init__(self, in_channels: int, out_channels: int, bias: bool = True, undir self.reset_parameters() - def reset_parameters(self): + def reset_parameters(self) -> None: size = 2 * self.in_channels uniform(size, self.weight) uniform(size, self.bias) @@ -80,12 +88,12 @@ def __repr__(self): return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})" -class SGAT(torch.nn.Module): +class SGAT(torch.nn.Module): # noqa:D101 def __init__( self, input_shape, output_shape=1, - input_shape_edge=None, # noqa: ARG002 (unused argument) + input_shape_edge=None, # noqa: ARG002 ): super().__init__() diff --git a/deeprank2/query.py b/deeprank2/query.py index 484074157..429c2d273 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -75,7 +75,8 @@ def __post_init__(self): self.max_edge_length = 4.5 if not self.max_edge_length else self.max_edge_length self.influence_radius = 4.5 if not self.influence_radius else self.influence_radius else: - raise ValueError(f"Invalid resolution given ({self.resolution}). Must be one of {VALID_RESOLUTIONS}") + msg = f"Invalid resolution given ({self.resolution}). 
Must be one of {VALID_RESOLUTIONS}" + raise ValueError(msg) if not isinstance(self.chain_ids, list): self.chain_ids = [self.chain_ids] @@ -86,7 +87,7 @@ def __post_init__(self): if value is None and f.default_factory is not MISSING: setattr(self, f.name, f.default_factory()) - def _set_graph_targets(self, graph: Graph): + def _set_graph_targets(self, graph: Graph) -> None: """Copy target data from query to graph.""" for target_name, target_data in self.targets.items(): graph.targets[target_name] = target_data @@ -97,14 +98,14 @@ def _load_structure(self) -> PDBStructure: try: structure = get_structure(pdb, self.model_id) finally: - pdb._close() # noqa: SLF001 (private-member-access) + pdb._close() # noqa: SLF001 # read the pssm if self._pssm_required: self._load_pssm_data(structure) return structure - def _load_pssm_data(self, structure: PDBStructure): + def _load_pssm_data(self, structure: PDBStructure) -> None: self._check_pssm() for chain in structure.chains: if chain.id in self.pssm_paths: @@ -112,7 +113,7 @@ def _load_pssm_data(self, structure: PDBStructure): with open(pssm_path, encoding="utf-8") as f: chain.pssm = parse_pssm(f, chain) - def _check_pssm(self, verbosity: Literal[0, 1, 2] = 0): + def _check_pssm(self, verbosity: Literal[0, 1, 2] = 0) -> None: # noqa: C901 """Checks whether information stored in pssm file matches the corresponding pdb file. Args: @@ -128,7 +129,8 @@ def _check_pssm(self, verbosity: Literal[0, 1, 2] = 0): ValueError: Raised if info between pdb file and pssm file doesn't match or if no pssms were provided """ if not self.pssm_paths: - raise ValueError("No pssm paths provided for conservation feature module.") + msg = "No pssm paths provided for conservation feature module." + raise ValueError(msg) # load residues from pssm and pdb files pssm_file_residues = {} @@ -146,7 +148,7 @@ def _check_pssm(self, verbosity: Literal[0, 1, 2] = 0): try: if pdb_file_residues[residue] != pssm_file_residues[residue]: mismatches.append(residue) - except KeyError: # noqa: PERF203 (try-except-in-loop) + except KeyError: # noqa: PERF203 missing_entries.append(residue) # generate error message @@ -155,11 +157,11 @@ def _check_pssm(self, verbosity: Literal[0, 1, 2] = 0): if verbosity: if len(mismatches) > 0: error_message = error_message + f"\n\t{len(mismatches)} entries are incorrect." - if verbosity == 2: + if verbosity == 2: # noqa: PLR2004 error_message = error_message[-1] + f":\n\t{missing_entries}" if len(missing_entries) > 0: error_message = error_message + f"\n\t{len(missing_entries)} entries are missing." - if verbosity == 2: + if verbosity == 2: # noqa: PLR2004 error_message = error_message[-1] + f":\n\t{missing_entries}" # raise exception (or warning) @@ -174,7 +176,7 @@ def model_id(self) -> str: return self._model_id @model_id.setter - def model_id(self, value: str): + def model_id(self, value: str) -> None: self._model_id = value def __repr__(self) -> str: @@ -207,10 +209,12 @@ def build( return graph def _build_helper(self) -> Graph: - raise NotImplementedError("Must be defined in child classes.") + msg = "Must be defined in child classes." + raise NotImplementedError(msg) def get_query_id(self) -> str: - raise NotImplementedError("Must be defined in child classes.") + msg = "Must be defined in child classes." 
+ raise NotImplementedError(msg) @dataclass(kw_only=True) @@ -282,7 +286,8 @@ def _build_helper(self) -> Graph: variant_residue = residue break if variant_residue is None: - raise ValueError(f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}") + msg = f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}" + raise ValueError(msg) self.variant = SingleResidueVariant(variant_residue, self.variant_amino_acid) residues = get_surrounding_residues( structure, @@ -309,7 +314,8 @@ def _build_helper(self) -> Graph: graph = Graph.build_graph(atoms, self.get_query_id(), self.max_edge_length) else: - raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") + msg = f"No function exists to build graphs with resolution of {self.resolution}." + raise NotImplementedError(msg) graph.center = variant_residue.get_center() return graph @@ -336,9 +342,9 @@ class ProteinProteinInterfaceQuery(Query): def __post_init__(self): super().__post_init__() - if len(self.chain_ids) != 2: + if len(self.chain_ids) != 2: # noqa: PLR2004 raise ValueError( - "`chain_ids` must contain exactly 2 chains for `ProteinProteinInterfaceQuery` objects, " + f"but {len(self.chain_ids)} was/were given." + "`chain_ids` must contain exactly 2 chains for `ProteinProteinInterfaceQuery` objects, " + f"but {len(self.chain_ids)} was/were given.", ) def get_query_id(self) -> str: @@ -361,7 +367,8 @@ def _build_helper(self) -> Graph: self.influence_radius, ) if len(contact_atoms) == 0: - raise ValueError("No contact atoms found") + msg = "No contact atoms found" + raise ValueError(msg) # build the graph if self.resolution == "atom": @@ -417,7 +424,7 @@ def add( query: Query, verbose: bool = False, warn_duplicate: bool = True, - ): + ) -> None: """Add a new query to the collection. Args: @@ -440,7 +447,7 @@ def add( self._queries.append(query) - def export_dict(self, dataset_path: str): + def export_dict(self, dataset_path: str) -> None: """Exports the colection of all queries to a dictionary file. Args: @@ -463,7 +470,7 @@ def __iter__(self) -> Iterator[Query]: def __len__(self) -> int: return len(self._queries) - def _process_one_query(self, query: Query): + def _process_one_query(self, query: Query) -> None: """Only one process may access an hdf5 file at a time.""" try: output_path = f"{self._prefix}-{os.getpid()}.hdf5" @@ -490,17 +497,14 @@ def _process_one_query(self, query: Query): except (ValueError, AttributeError, KeyError, TimeoutError) as e: _log.warning( f"\nGraph/Query with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e})," - " and it has not been written to the hdf5 file. More details below:" + " and it has not been written to the hdf5 file. More details below:", ) _log.exception(e) def process( self, prefix: str = "processed-queries", - feature_modules: list[ModuleType, str] | ModuleType | str | Literal["all"] = [ # noqa: B006, PYI051 (mutable-argument-default, redundant-literal-union) - components, - contact, - ], + feature_modules: list[ModuleType, str] | ModuleType | str | None = None, cpu_count: int | None = None, combine_output: bool = True, grid_settings: GridSettings | None = None, @@ -533,6 +537,7 @@ def process( list[str]: The list of paths of the generated HDF5 files. 
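Just above, the `process()` signature drops its mutable list default (the construct ruff flags as B006) in favour of `None`, and the body then restores the intended default with `feature_modules = feature_modules or [components, contact]`. Below is a short, hypothetical sketch of why that idiom is preferred; the function and module names are illustrative only and do not come from this repository.

```python
# Hypothetical illustration of the B006 fix pattern: a mutable default such as
# `modules: list[str] = ["components", "contact"]` is created once at function
# definition time and shared between all calls, so it is replaced by None plus
# an explicit per-call fallback inside the body.
def select_modules(modules: list[str] | None = None) -> list[str]:
    modules = modules or ["components", "contact"]  # fresh list on every call
    return sorted(modules)


print(select_modules())          # ['components', 'contact']
print(select_modules(["sasa"]))  # ['sasa']
```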
""" # set defaults + feature_modules = feature_modules or [components, contact] self._prefix = "processed-queries" if not prefix else re.sub(".hdf5$", "", prefix) # scrape extension if present max_cpus = os.cpu_count() @@ -548,7 +553,8 @@ def process( self._grid_map_method = grid_map_method if grid_augmentation_count < 0: - raise ValueError(f"`grid_augmentation_count` cannot be negative, but was given as {grid_augmentation_count}") + msg = f"`grid_augmentation_count` cannot be negative, but was given as {grid_augmentation_count}" + raise ValueError(msg) self._grid_augmentation_count = grid_augmentation_count _log.info(f"Creating pool function to process {len(self)} queries...") @@ -569,7 +575,7 @@ def process( return output_paths - def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleType | str | Literal["all"]) -> list[str]: # noqa: PYI051 (redundant-literal-union) + def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleType | str) -> list[str]: """Convert `feature_modules` to list[str] irrespective of input type. Raises: @@ -584,9 +590,11 @@ def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleTy if isinstance(feature_modules, list): invalid_inputs = [type(el) for el in feature_modules if not isinstance(el, str | ModuleType)] if invalid_inputs: - raise TypeError(f"`feature_modules` contains invalid input ({invalid_inputs}). Only `str` and `ModuleType` are accepted.") + msg = f"`feature_modules` contains invalid input ({invalid_inputs}). Only `str` and `ModuleType` are accepted." + raise TypeError(msg) return [ re.sub(".py$", "", m) if isinstance(m, str) else os.path.basename(m.__file__)[:-3] # for ModuleTypes for m in feature_modules ] - raise TypeError(f"`feature_modules` has received an invalid input type: {type(feature_modules)}. Only `str` and `ModuleType` are accepted.") + msg = f"`feature_modules` has received an invalid input type: {type(feature_modules)}. Only `str` and `ModuleType` are accepted." + raise TypeError(msg) diff --git a/deeprank2/tools/target.py b/deeprank2/tools/target.py index 059712c7d..4721633a2 100644 --- a/deeprank2/tools/target.py +++ b/deeprank2/tools/target.py @@ -1,4 +1,5 @@ import glob +import logging import os import h5py @@ -7,13 +8,16 @@ from deeprank2.domain import targetstorage as targets +_log = logging.getLogger(__name__) +MIN_IRMS_FOR_BINARY = 4 -def add_target( + +def add_target( # noqa: C901 graph_path: str | list[str], target_name: str, target_list: str, sep: str = " ", -): +) -> None: """Add a target to all the graphs in hdf5 files. Args: @@ -45,18 +49,21 @@ def add_target( elif isinstance(graph_path, list): graphs = graph_path else: - raise TypeError("Incorrect input passed.") + msg = "Incorrect input passed." + raise TypeError(msg) for hdf5 in graphs: - print(hdf5) + _log.info(hdf5) if not os.path.isfile(hdf5): - raise FileNotFoundError(f"File {hdf5} not found.") + msg = f"File {hdf5} not found." + raise FileNotFoundError(msg) try: f5 = h5py.File(hdf5, "a") for model in target_dict: if model not in f5: - raise ValueError(f"{hdf5} does not contain an entry named {model}.") # noqa: TRY301 (raise-within-try) + msg = f"{hdf5} does not contain an entry named {model}." 
+ raise ValueError(msg) # noqa: TRY301 try: model_gp = f5[model] if targets.VALUES not in model_gp: @@ -67,12 +74,12 @@ def add_target( del group[target_name] # Create the target group.create_dataset(target_name, data=target_dict[model]) - except BaseException: # noqa: BLE001 (blind-except) - print(f"no graph for {model}") + except BaseException: # noqa: BLE001 + _log.info(f"no graph for {model}") f5.close() - except BaseException: # noqa: BLE001 (blind-except) - print(f"no graph for {hdf5}") + except BaseException: # noqa: BLE001 + _log.info(f"no graph for {hdf5}") def compute_ppi_scores( @@ -113,7 +120,7 @@ def compute_ppi_scores( scores[targets.FNAT] = sim.compute_fnat_fast() scores[targets.DOCKQ] = sim.compute_DockQScore(scores[targets.FNAT], scores[targets.LRMSD], scores[targets.IRMSD]) - scores[targets.BINARY] = scores[targets.IRMSD] < 4.0 + scores[targets.BINARY] = scores[targets.IRMSD] < MIN_IRMS_FOR_BINARY scores[targets.CAPRI] = 4 for thr, val in zip([4.0, 2.0, 1.0], [3, 2, 1], strict=True): diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py index 3a5ef32d3..3f97dbb4d 100644 --- a/deeprank2/trainer.py +++ b/deeprank2/trainer.py @@ -4,6 +4,7 @@ import re import warnings from time import time +from typing import Any import h5py import numpy as np @@ -25,9 +26,36 @@ class Trainer: - def __init__( # noqa: PLR0915 (too-many-statements) + """Class from which the network is trained, evaluated and tested. + + Args: + neuralnet (child class of :class:`torch.nn.Module`, optional): Neural network class (ex. :class:`GINet`, :class:`Foutnet` etc.). + It should subclass :class:`torch.nn.Module`, and it shouldn't be specific to regression or classification + in terms of output shape (:class:`Trainer` class takes care of formatting the output shape according to the task). + More specifically, in classification task cases, softmax shouldn't be used as the last activation function. + Defaults to None. + dataset_train (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Training set used during training. + Can't be None if pretrained_model is also None. Defaults to None. + dataset_val (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Evaluation set used during training. + If None, training set will be split randomly into training set and validation set during training, using val_size parameter. + Defaults to None. + dataset_test (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Independent evaluation set. Defaults to None. + val_size (float | int | None, optional): Fraction of dataset (if float) or number of datapoints (if int) to use for validation. + Only used if dataset_val is not specified. Can be set to 0 if no validation set is needed. Defaults to None (in _divide_dataset function). + test_size (float | int | None, optional): Fraction of dataset (if float) or number of datapoints (if int) to use for test dataset. + Only used if dataset_test is not specified. Can be set to 0 if no test set is needed. Defaults to None. + class_weights (bool, optional): Assign class weights based on the dataset content. Defaults to False. + pretrained_model (str | None, optional): Path to pre-trained model. Defaults to None. + cuda (bool, optional): Whether to use CUDA. Defaults to False. + ngpu (int, optional): Number of GPU to be used. Defaults to 0. + output_exporters (list[OutputExporter] | None, optional): The output exporters to use for saving/exploring/plotting predictions/targets/losses + over the epochs. 
If None, defaults to :class:`HDF5OutputExporter`, which saves all the results in an .HDF5 file stored in ./output directory. + Defaults to None. + """ + + def __init__( # noqa: PLR0915, C901 self, - neuralnet=None, + neuralnet: nn.Module = None, dataset_train: GraphDataset | GridDataset | None = None, dataset_val: GraphDataset | GridDataset | None = None, dataset_test: GraphDataset | GridDataset | None = None, @@ -39,32 +67,6 @@ def __init__( # noqa: PLR0915 (too-many-statements) ngpu: int = 0, output_exporters: list[OutputExporter] | None = None, ): - """Class from which the network is trained, evaluated and tested. - - Args: - neuralnet (child class of :class:`torch.nn.Module`, optional): Neural network class (ex. :class:`GINet`, :class:`Foutnet` etc.). - It should subclass :class:`torch.nn.Module`, and it shouldn't be specific to regression or classification - in terms of output shape (:class:`Trainer` class takes care of formatting the output shape according to the task). - More specifically, in classification task cases, softmax shouldn't be used as the last activation function. - Defaults to None. - dataset_train (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Training set used during training. - Can't be None if pretrained_model is also None. Defaults to None. - dataset_val (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Evaluation set used during training. - If None, training set will be split randomly into training set and validation set during training, using val_size parameter. - Defaults to None. - dataset_test (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Independent evaluation set. Defaults to None. - val_size (float | int | None, optional): Fraction of dataset (if float) or number of datapoints (if int) to use for validation. - Only used if dataset_val is not specified. Can be set to 0 if no validation set is needed. Defaults to None (in _divide_dataset function). - test_size (float | int | None, optional): Fraction of dataset (if float) or number of datapoints (if int) to use for test dataset. - Only used if dataset_test is not specified. Can be set to 0 if no test set is needed. Defaults to None. - class_weights (bool, optional): Assign class weights based on the dataset content. Defaults to False. - pretrained_model (str | None, optional): Path to pre-trained model. Defaults to None. - cuda (bool, optional): Whether to use CUDA. Defaults to False. - ngpu (int, optional): Number of GPU to be used. Defaults to 0. - output_exporters (list[OutputExporter] | None, optional): The output exporters to use for saving/exploring/plotting predictions/targets/losses - over the epochs. If None, defaults to :class:`HDF5OutputExporter`, which saves all the results in an .HDF5 file stored in ./output directory. - Defaults to None. - """ self.neuralnet = neuralnet self.pretrained_model = pretrained_model @@ -85,15 +87,16 @@ def __init__( # noqa: PLR0915 (too-many-statements) and that you are running on GPUs.\n --> To turn CUDA off set cuda=False in Trainer.\n --> Aborting the experiment \n\n' - """ + """, ) - raise ValueError( - """ + msg = """ --> CUDA not detected: Make sure that CUDA is installed and that you are running on GPUs.\n --> To turn CUDA off set cuda=False in Trainer.\n --> Aborting the experiment \n\n' """ + raise ValueError( + msg, ) else: self.device = torch.device("cpu") @@ -103,14 +106,15 @@ def __init__( # noqa: PLR0915 (too-many-statements) --> CUDA not detected. 
Set cuda=True in Trainer to turn CUDA on.\n --> Aborting the experiment \n\n - """ + """, ) - raise ValueError( - """ + msg = """ --> CUDA not detected. Set cuda=True in Trainer to turn CUDA on.\n --> Aborting the experiment \n\n """ + raise ValueError( + msg, ) _log.info(f"Device set to {self.device}.") @@ -129,9 +133,11 @@ def __init__( # noqa: PLR0915 (too-many-statements) if self.pretrained_model is None: if self.dataset_train is None: - raise ValueError("No training data specified. Training data is required if there is no pretrained model.") + msg = "No training data specified. Training data is required if there is no pretrained model." + raise ValueError(msg) if self.neuralnet is None: - raise ValueError("No neural network specified. Specifying a model framework is required if there is no pretrained model.") + msg = "No neural network specified. Specifying a model framework is required if there is no pretrained model." + raise ValueError(msg) self._init_from_dataset(self.dataset_train) self.optimizer = None @@ -140,7 +146,8 @@ def __init__( # noqa: PLR0915 (too-many-statements) self.epoch_saved_model = None if self.target is None: - raise ValueError("No target set. You need to choose a target (set in the dataset) for training.") + msg = "No target set. You need to choose a target (set in the dataset) for training." + raise ValueError(msg) self._load_model() @@ -159,13 +166,16 @@ def __init__( # noqa: PLR0915 (too-many-statements) if self.dataset_test is not None: self._precluster(self.dataset_test) else: - raise ValueError(f"Invalid node clustering method: {self.clustering_method}. Please set clustering_method to 'mcl', 'louvain' or None.") + msg = f"Invalid node clustering method: {self.clustering_method}. Please set clustering_method to 'mcl', 'louvain' or None." + raise ValueError(msg) else: if self.neuralnet is None: - raise ValueError("No neural network class found. Please add it to complete loading the pretrained model.") + msg = "No neural network class found. Please add it to complete loading the pretrained model." + raise ValueError(msg) if self.dataset_test is None: - raise ValueError("No dataset_test found. Please add it to evaluate the pretrained model.") + msg = "No dataset_test found. Please add it to evaluate the pretrained model." 
+ raise ValueError(msg) if self.dataset_train is not None: self.dataset_train = None _log.warning("Pretrained model loaded: dataset_train will be ignored.") @@ -176,7 +186,7 @@ def __init__( # noqa: PLR0915 (too-many-statements) self._load_params() self._load_pretrained_model() - def _init_output_exporters(self, output_exporters: list[OutputExporter] | None): + def _init_output_exporters(self, output_exporters: list[OutputExporter] | None) -> None: if output_exporters is not None: self._output_exporters = OutputExporterCollection(*output_exporters) else: @@ -189,7 +199,7 @@ def _init_datasets( dataset_test: GraphDataset | GridDataset | None, val_size: int | float | None, test_size: int | float | None, - ): + ) -> None: self._check_dataset_equivalence(dataset_train, dataset_val, dataset_test) self.dataset_train = dataset_train @@ -211,7 +221,7 @@ def _init_datasets( else: _log.warning("Validation dataset was provided to Trainer; val_size parameter is ignored.") - def _init_from_dataset(self, dataset: GraphDataset | GridDataset): + def _init_from_dataset(self, dataset: GraphDataset | GridDataset) -> None: if isinstance(dataset, GraphDataset): self.clustering_method = dataset.clustering_method self.node_features = dataset.node_features @@ -230,7 +240,8 @@ def _init_from_dataset(self, dataset: GraphDataset | GridDataset): self.means = None self.devs = None else: - raise TypeError(f"Incorrect `dataset` type provided: {type(dataset)}. Please provide a `GridDataset` or `GraphDataset` object instead.") + msg = f"Incorrect `dataset` type provided: {type(dataset)}. Please provide a `GridDataset` or `GraphDataset` object instead." + raise TypeError(msg) self.target = dataset.target self.target_transform = dataset.target_transform @@ -238,23 +249,30 @@ def _init_from_dataset(self, dataset: GraphDataset | GridDataset): self.classes = dataset.classes self.classes_to_index = dataset.classes_to_index - def _load_model(self): + def _load_model(self) -> None: """Loads the neural network model.""" self._put_model_to_device(self.dataset_train) self.configure_optimizers() self.set_lossfunction() - def _check_dataset_equivalence(self, dataset_train, dataset_val, dataset_test): + def _check_dataset_equivalence( + self, + dataset_train: GraphDataset | GridDataset, + dataset_val: GraphDataset | GridDataset, + dataset_test: GraphDataset | GridDataset, + ) -> None: """Check dataset_train type and train_source parameter settings.""" # dataset_train is None when pretrained_model is set if dataset_train is None: # only check the test dataset if dataset_test is None: - raise ValueError("Please provide at least a train or test dataset") + msg = "Please provide at least a train or test dataset" + raise ValueError(msg) else: # Make sure train dataset has valid type if not isinstance(dataset_train, GraphDataset) and not isinstance(dataset_train, GridDataset): - raise TypeError(f"train dataset is not the right type {type(dataset_train)}. Make sure it's either GraphDataset or GridDataset") + msg = f"train dataset is not the right type {type(dataset_train)}. 
Make sure it's either GraphDataset or GridDataset" + raise TypeError(msg) if dataset_val is not None: self._check_dataset_value( @@ -270,18 +288,25 @@ def _check_dataset_equivalence(self, dataset_train, dataset_val, dataset_test): type_dataset="test", ) - def _check_dataset_value(self, dataset_train, dataset_check, type_dataset): + def _check_dataset_value( + self, + dataset_train: GraphDataset | GridDataset, + dataset_check: GraphDataset | GridDataset, + type_dataset: str, + ) -> None: """Check valid/test dataset settings.""" # Check train_source parameter in valid/test is set. if dataset_check.train_source is None: - raise ValueError(f"{type_dataset} dataset has train_source parameter set to None. Make sure to set it as a valid training data source.") + msg = f"{type_dataset} dataset has train_source parameter set to None. Make sure to set it as a valid training data source." + raise ValueError(msg) # Check train_source parameter in valid/test is equivalent to train which passed to Trainer. if dataset_check.train_source != dataset_train: + msg = f"{type_dataset} dataset has different train_source parameter from Trainer. Make sure to assign equivalent train_source in Trainer." raise ValueError( - f"{type_dataset} dataset has different train_source parameter from Trainer. Make sure to assign equivalent train_source in Trainer." + msg, ) - def _load_pretrained_model(self): + def _load_pretrained_model(self) -> None: """Loads pretrained model.""" self.test_loader = DataLoader(self.dataset_test, pin_memory=self.cuda) _log.info("Testing set loaded\n") @@ -296,7 +321,7 @@ def _load_pretrained_model(self): self.optimizer.load_state_dict(self.opt_loaded_state_dict) self.model.load_state_dict(self.model_load_state_dict) - def _precluster(self, dataset: GraphDataset): + def _precluster(self, dataset: GraphDataset) -> None: """Pre-clusters nodes of the graphs.""" for fname, mol in tqdm(dataset.index_entries): data = dataset.load_one_graph(fname, mol) @@ -306,7 +331,7 @@ def _precluster(self, dataset: GraphDataset): try: _log.info(f"deleting {mol}") del f5[mol] - except BaseException: # noqa: BLE001 (blind-except) + except BaseException: # noqa: BLE001 _log.info(f"{mol} not found") f5.close() continue @@ -327,7 +352,7 @@ def _precluster(self, dataset: GraphDataset): f5.close() - def _put_model_to_device(self, dataset: GraphDataset | GridDataset): + def _put_model_to_device(self, dataset: GraphDataset | GridDataset) -> None: """ Puts the model on the available device. @@ -372,18 +397,21 @@ def _put_model_to_device(self, dataset: GraphDataset | GridDataset): # check for compatibility for output_exporter in self._output_exporters: if not output_exporter.is_compatible_with(self.output_shape, target_shape): - raise ValueError( + msg = ( f"Output exporter of type {type(output_exporter)}\n\t" f"is not compatible with output shape {self.output_shape}\n\t" f"and target shape {target_shape}." ) + raise ValueError( + msg, + ) def configure_optimizers( self, optimizer: torch.optim = None, lr: float = 0.001, weight_decay: float = 1e-05, - ): + ) -> None: """ Configure optimizer and its main parameters. @@ -409,11 +437,11 @@ def configure_optimizers( _log.info("Invalid optimizer. Please use only optimizers classes from torch.optim package.") raise - def set_lossfunction( + def set_lossfunction( # noqa: C901 self, - lossfunction=None, + lossfunction: nn.modules.loss._Loss | None = None, override_invalid: bool = False, - ): + ) -> None: """ Set the loss function. 
@@ -433,12 +461,12 @@ def set_lossfunction( default_regression_loss = nn.MSELoss default_classification_loss = nn.CrossEntropyLoss - def _invalid_loss(): + def _invalid_loss() -> None: if override_invalid: _log.warning( f"The provided loss function ({lossfunction}) is not appropriate for {self.task} tasks.\n\t" "You have set override_invalid to True, so the training will run with this loss function nonetheless.\n\t" - "This will likely cause other errors or exceptions down the line." + "This will likely cause other errors or exceptions down the line.", ) else: invalid_loss_error = ( @@ -490,7 +518,7 @@ def _invalid_loss(): else: self.lossfunction = lossfunction # weights will be set in the train() method - def train( # noqa: PLR0915 (too-many-statements) + def train( # noqa: PLR0915, C901 self, nepoch: int = 1, batch_size: int = 32, @@ -502,7 +530,7 @@ def train( # noqa: PLR0915 (too-many-statements) num_workers: int = 0, best_model: bool = True, filename: str | None = "model.pth.tar", - ): + ) -> None: """ Performs the training of the model. @@ -531,7 +559,8 @@ def train( # noqa: PLR0915 (too-many-statements) If None, the model is not saved. Defaults to 'model.pth.tar'. """ if self.dataset_train is None: - raise ValueError("No training dataset provided.") + msg = "No training dataset provided." + raise ValueError(msg) self.data_type = type(self.dataset_train) self.batch_size_train = batch_size @@ -560,7 +589,7 @@ def train( # noqa: PLR0915 (too-many-statements) _log.info("No validation set provided\n") _log.warning( "Training data will be used both for learning and model selection, which may lead to overfitting.\n" - "It is usually preferable to use a validation set during the training phase." + "It is usually preferable to use a validation set during the training phase.", ) # Assign weights to each class @@ -607,7 +636,8 @@ def train( # noqa: PLR0915 (too-many-statements) self._eval(self.train_loader, 0, "training") if validate: if self.valid_loader is None: - raise ValueError("No validation dataset provided.") + msg = "No validation dataset provided." + raise ValueError(msg) self._eval(self.valid_loader, 0, "validation") # Loop over epochs @@ -650,7 +680,7 @@ def train( # noqa: PLR0915 (too-many-statements) if not saved_model: warnings.warn( "A model has been saved but the validation and/or the training losses were NaN;\n\t" - "try to increase the cutoff distance during the data processing or the number of data points during the training." 
+ "try to increase the cutoff distance during the data processing or the number of data points during the training.", ) # Now that the training loop is over, save the model @@ -680,7 +710,7 @@ def _epoch(self, epoch_number: int, pass_name: str) -> float | None: t0 = time() for data_batch in self.train_loader: if self.cuda: - data_batch = data_batch.to(self.device, non_blocking=True) # noqa: PLW2901 (redefined-loop-name) + data_batch = data_batch.to(self.device, non_blocking=True) # noqa: PLW2901 self.optimizer.zero_grad() pred = self.model(data_batch) pred, data_batch.y = self._format_output(pred, data_batch.y) @@ -750,7 +780,7 @@ def _eval( t0 = time() for data_batch in loader: if self.cuda: - data_batch = data_batch.to(self.device, non_blocking=True) # noqa: PLW2901 (redefined-loop-name) + data_batch = data_batch.to(self.device, non_blocking=True) # noqa: PLW2901 pred = self.model(data_batch) pred, y = self._format_output(pred, data_batch.y) @@ -794,7 +824,7 @@ def _eval( return eval_loss @staticmethod - def _log_epoch_data(stage: str, loss: float, time: float): + def _log_epoch_data(stage: str, loss: float, time: float) -> None: """ Prints the data of each epoch. @@ -805,7 +835,7 @@ def _log_epoch_data(stage: str, loss: float, time: float): """ _log.info(f"{stage} loss {loss} | time {time}") - def _format_output(self, pred, target=None): + def _format_output(self, pred, target=None): # noqa: ANN001, ANN202 """Format the network output depending on the task (classification/regression).""" if (self.task == targets.CLASSIF) and (target is not None): # For categorical cross entropy, the target must be a one-dimensional tensor @@ -813,17 +843,23 @@ def _format_output(self, pred, target=None): target = torch.tensor([self.classes_to_index[x] if isinstance(x, str) else self.classes_to_index[int(x)] for x in target]) if isinstance(self.lossfunction, nn.BCELoss | nn.BCEWithLogitsLoss): # # pred must be in (0,1) range and target must be float with same shape as pred - raise ValueError( + msg = ( "BCELoss and BCEWithLogitsLoss are currently not supported.\n\t" "For further details see: https://github.com/DeepRank/deeprank2/issues/318" ) + raise ValueError( + msg, + ) if isinstance(self.lossfunction, losses.classification_losses) and not isinstance(self.lossfunction, losses.classification_tested): - raise ValueError( + msg = ( f"{self.lossfunction} is currently not supported.\n\t" f"Supported loss functions for classification: {losses.classification_tested}.\n\t" "Implementation of other loss functions requires adaptation of Trainer._format_output." ) + raise ValueError( + msg, + ) elif self.task == targets.REGRESS: pred = pred.reshape(-1) @@ -837,7 +873,7 @@ def test( self, batch_size: int = 32, num_workers: int = 0, - ): + ) -> None: """ Performs the testing of the model. @@ -848,7 +884,8 @@ def test( Defaults to 0. """ if (not self.pretrained_model) and (not self.model_load_state_dict): - raise ValueError("No pretrained model provided and no training performed. Please provide a pretrained model or train the model before testing.") + msg = "No pretrained model provided and no training performed. Please provide a pretrained model or train the model before testing." + raise ValueError(msg) self.batch_size_test = batch_size @@ -864,13 +901,14 @@ def test( _log.info("Testing set loaded\n") else: _log.error("No test dataset provided.") - raise ValueError("No test dataset provided.") + msg = "No test dataset provided." 
+ raise ValueError(msg) with self._output_exporters: # Run test self._eval(self.test_loader, self.epoch_saved_model, "testing") - def _load_params(self): + def _load_params(self) -> None: """Loads the parameters of a pretrained model.""" if torch.cuda.is_available(): state = torch.load(self.pretrained_model) @@ -907,7 +945,7 @@ def _load_params(self): self.cuda = state["cuda"] self.ngpu = state["ngpu"] - def _save_model(self): + def _save_model(self) -> dict[str, Any]: """ Saves the model to a file. @@ -980,12 +1018,15 @@ def _divide_dataset( elif isinstance(splitsize, int): n_split = splitsize else: - raise TypeError(f"type(splitsize) must be float, int or None ({type(splitsize)} detected.)") + msg = f"type(splitsize) must be float, int or None ({type(splitsize)} detected.)" + raise TypeError(msg) # raise exception if no training data or negative validation size if n_split >= full_size or n_split < 0: + msg = f"Invalid Split size: {n_split}.\n" + f"Split size must be a float between 0 and 1 OR an int smaller than the size of the dataset ({full_size} datapoints)" raise ValueError( - f"Invalid splitsize: {n_split}. splitsize must be a float between 0 and 1 OR an int smaller than the size of the dataset ({full_size} datapoints)" + msg, ) if splitsize == 0: # i.e. the fraction of splitsize was so small that it rounded to <1 datapoint diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py index 1e49377bd..a6e9b3c63 100644 --- a/deeprank2/utils/buildgraph.py +++ b/deeprank2/utils/buildgraph.py @@ -15,7 +15,7 @@ _log = logging.getLogger(__name__) -def _add_atom_to_residue(atom: Atom, residue: Residue): +def _add_atom_to_residue(atom: Atom, residue: Residue) -> None: """Adds an `Atom` to a `Residue` if not already there. If no matching atom is found, add the current atom to the residue. @@ -31,8 +31,8 @@ def _add_atom_to_residue(atom: Atom, residue: Residue): def _add_atom_data_to_structure( structure: PDBStructure, pdb_obj: pdb2sql_object, - **kwargs, -): + **kwargs, # noqa: ANN003 +) -> None: """This subroutine retrieves pdb2sql atomic data for `PDBStructure` objects as defined in DeepRank2. This function should be called for one atom at a time. @@ -111,7 +111,7 @@ def get_contact_atoms( pdb_rowID = atom_indexes[chain_ids[0]] + atom_indexes[chain_ids[1]] _add_atom_data_to_structure(structure, interface, rowID=pdb_rowID) finally: - interface._close() # noqa: SLF001 (private-member-access) + interface._close() # noqa: SLF001 return structure.get_atoms() @@ -145,7 +145,7 @@ def get_residue_contact_pairs( return_contact_pairs=True, ) finally: - interface._close() # noqa: SLF001 (private-member-access) + interface._close() # noqa: SLF001 # Map to residue objects residue_pairs = set() @@ -169,7 +169,8 @@ def _get_residue_from_key( for residue in chain.residues: if residue.number == residue_number and residue.amino_acid is not None and residue.amino_acid.three_letter_code == residue_name: return residue - raise ValueError(f"Residue ({residue_key}) not found in {structure.id}.") + msg = f"Residue ({residue_key}) not found in {structure.id}." 
+ raise ValueError(msg) def get_surrounding_residues( diff --git a/deeprank2/utils/community_pooling.py b/deeprank2/utils/community_pooling.py index 747f77c32..04384268f 100644 --- a/deeprank2/utils/community_pooling.py +++ b/deeprank2/utils/community_pooling.py @@ -11,14 +11,16 @@ from torch_geometric.nn.pool.pool import pool_batch, pool_edge from torch_scatter import scatter_max, scatter_mean +# ruff: noqa: ANN001, ANN201 (missing type hints and return types) -def plot_graph(graph, cluster): + +def plot_graph(graph, cluster) -> None: # noqa:D103 pos = nx.spring_layout(graph, iterations=200) nx.draw(graph, pos, node_color=cluster) plt.show() -def get_preloaded_cluster(cluster, batch): +def get_preloaded_cluster(cluster, batch): # noqa:D103 nbatch = torch.max(batch) + 1 for ib in range(1, nbatch): cluster[batch == ib] += torch.max(cluster[batch == ib - 1]) + 1 @@ -85,7 +87,8 @@ def community_detection_per_batch( ncluster = max(cluster) else: - raise ValueError(f"Clustering method {method} not supported") + msg = f"Clustering method {method} not supported" + raise ValueError(msg) device = edge_index.device return torch.tensor(cluster).to(device) @@ -155,7 +158,8 @@ def community_detection( return torch.tensor(index).to(device) - raise ValueError(f"Clustering method {method} not supported") + msg = f"Clustering method {method} not supported" + raise ValueError(msg) def community_pooling(cluster, data): diff --git a/deeprank2/utils/earlystopping.py b/deeprank2/utils/earlystopping.py index 02e9bffd7..7b08c6605 100644 --- a/deeprank2/utils/earlystopping.py +++ b/deeprank2/utils/earlystopping.py @@ -2,6 +2,25 @@ class EarlyStopping: + """Terminate training upon trigger. + + Triggered if validation loss doesn't improve after a given patience or if a maximum gap between validation and training loss is reached. + + Args: + patience (int, optional): How long to wait after last time validation loss improved. + Defaults to 10. + delta (float, optional): Minimum change required to reset the early stopping counter. + Defaults to 0. + maxgap (float, optional): Maximum difference between between training and validation loss. + Defaults to None. + min_epoch (float, optional): Minimum epoch to be reached before looking at maxgap. + Defaults to 10. + verbose (bool, optional): If True, prints a message for each validation loss improvement. + Defaults to True. + trace_func (Callable, optional): Function used for recording EarlyStopping status. + Defaults to print. + """ + def __init__( self, patience: int = 10, @@ -11,24 +30,6 @@ def __init__( verbose: bool = True, trace_func: Callable = print, ): - """Terminate training upon trigger. - - Triggered if validation loss doesn't improve after a given patience or if a maximum gap between validation and training loss is reached. - - Args: - patience (int, optional): How long to wait after last time validation loss improved. - Defaults to 10. - delta (float, optional): Minimum change required to reset the early stopping counter. - Defaults to 0. - maxgap (float, optional): Maximum difference between between training and validation loss. - Defaults to None. - min_epoch (float, optional): Minimum epoch to be reached before looking at maxgap. - Defaults to 10. - verbose (bool, optional): If True, prints a message for each validation loss improvement. - Defaults to True. - trace_func (Callable, optional): Function used for recording EarlyStopping status. - Defaults to print. 
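The new class-level docstring for `EarlyStopping` above describes a patience counter plus an optional `maxgap` overfitting check, and the annotated `__call__` further down takes the epoch number, the validation loss and, optionally, the training loss. Below is a rough usage sketch under those assumptions; the loop and the loss values are invented for illustration, and it assumes the `early_stop` flag starts out `False` (its initialisation is not shown in this hunk).

```python
from deeprank2.utils.earlystopping import EarlyStopping

# Hypothetical training loop driving the EarlyStopping helper described above;
# the loss values are made up and stand in for real per-epoch results.
stopper = EarlyStopping(patience=3, delta=0.0, maxgap=1.0, min_epoch=5)

fake_losses = [(0.9, 1.0), (0.7, 0.9), (0.6, 0.9), (0.55, 0.95), (0.5, 1.0)]
for epoch, (train_loss, val_loss) in enumerate(fake_losses):
    # ... one epoch of training and validation would run here ...
    stopper(epoch, val_loss, train_loss)
    if stopper.early_stop:  # assumed to be False until a trigger fires
        print(f"stopping at epoch {epoch}")
        break
```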
- """ self.patience = patience self.delta = delta self.maxgap = maxgap @@ -41,7 +42,12 @@ def __init__( self.best_score = None self.val_loss_min = None - def __call__(self, epoch, val_loss, train_loss=None): + def __call__( # noqa: C901 + self, + epoch: int, + val_loss: float, + train_loss: float | None = None, + ): score = -val_loss # initialize @@ -53,13 +59,12 @@ def __call__(self, epoch, val_loss, train_loss=None): elif score < self.best_score + self.delta: self.counter += 1 if self.verbose: + extra_trace = "" if self.delta: extra_trace = f"more than {self.delta} " - else: - extra_trace = "" self.trace_func( f"Validation loss did not decrease {extra_trace}({self.val_loss_min:.6f} --> {val_loss:.6f}). " - f"EarlyStopping counter: {self.counter} out of {self.patience}" + f"EarlyStopping counter: {self.counter} out of {self.patience}", ) if self.counter >= self.patience: self.trace_func(f"EarlyStopping activated at epoch # {epoch} because patience of {self.patience} has been reached.") @@ -77,12 +82,13 @@ def __call__(self, epoch, val_loss, train_loss=None): # check maxgap if self.maxgap and epoch > self.min_epoch: if train_loss is None: - raise ValueError("Cannot compute gap because no train_loss is provided to EarlyStopping.") + msg = "Cannot compute gap because no train_loss is provided to EarlyStopping." + raise ValueError(msg) gap = val_loss - train_loss if gap > self.maxgap: self.trace_func( f"EarlyStopping activated at epoch # {epoch} due to overfitting. " - f"The difference between validation and training loss of {gap} exceeds the maximum allowed ({self.maxgap})" + f"The difference between validation and training loss of {gap} exceeds the maximum allowed ({self.maxgap})", ) self.early_stop = True diff --git a/deeprank2/utils/exporters.py b/deeprank2/utils/exporters.py index a262059c7..04756e145 100644 --- a/deeprank2/utils/exporters.py +++ b/deeprank2/utils/exporters.py @@ -28,7 +28,7 @@ def __enter__(self): """Overridable.""" return self - def __exit__(self, exception_type, exception, traceback): + def __exit__(self, exception_type, exception, traceback): # noqa: ANN001 """Overridable.""" def process( @@ -39,13 +39,13 @@ def process( output_values: list, target_values: list, loss: float, - ): + ) -> None: """The entry_names, output_values, target_values MUST have the same length.""" def is_compatible_with( self, - output_data_shape: int, # noqa: ARG002 (unused argument) - target_data_shape: int | None = None, # noqa: ARG002 (unused argument) + output_data_shape: int, # noqa: ARG002 + target_data_shape: int | None = None, # noqa: ARG002 ) -> bool: """True if this exporter can work with the given data shapes.""" return True @@ -63,7 +63,7 @@ def __enter__(self): return self - def __exit__(self, exception_type, exception, traceback): + def __exit__(self, exception_type, exception, traceback): # noqa: ANN001 for output_exporter in self._output_exporters: output_exporter.__exit__(exception_type, exception, traceback) @@ -75,7 +75,7 @@ def process( output_values: list, target_values: list, loss: float, - ): + ) -> None: for output_exporter in self._output_exporters: output_exporter.process( pass_name, @@ -108,7 +108,7 @@ def __enter__(self): self._writer.__enter__() return self - def __exit__(self, exception_type, exception, traceback): + def __exit__(self, exception_type, exception, traceback): # noqa: ANN001 self._writer.__exit__(exception_type, exception, traceback) def process( @@ -118,8 +118,8 @@ def process( entry_names: list[str], output_values: list, target_values: list, - 
loss: float, # noqa: ARG002 (unused argument) - ): + loss: float, # noqa: ARG002 + ) -> None: """Write to tensorboard.""" ce_loss = cross_entropy(tensor(output_values), tensor(target_values)).item() self._writer.add_scalar( @@ -137,25 +137,25 @@ def process( prediction_value = argmax(tensor(output_values[entry_index])) target_value = target_values[entry_index] - if prediction_value > 0.0 and target_value > 0.0: + if prediction_value > 0 and target_value > 0: tp += 1 - elif prediction_value <= 0.0 and target_value <= 0.0: + elif prediction_value <= 0 and target_value <= 0: tn += 1 - elif target_value <= 0.0 < prediction_value: + elif target_value <= 0 < prediction_value: fp += 1 - elif prediction_value <= 0.0 < target_value: + elif prediction_value <= 0 < target_value: fn += 1 mcc_numerator = tn * tp - fp * fn - if mcc_numerator == 0.0: + if mcc_numerator == 0: self._writer.add_scalar(f"{pass_name} MCC", 0.0, epoch_number) else: mcc_denominator = sqrt((tn + fn) * (fp + tp) * (tn + fp) * (fn + tp)) - if mcc_denominator != 0.0: + if mcc_denominator != 0: mcc = mcc_numerator / mcc_denominator self._writer.add_scalar(f"{pass_name} MCC", mcc, epoch_number) @@ -163,7 +163,7 @@ def process( self._writer.add_scalar(f"{pass_name} accuracy", accuracy, epoch_number) # for ROC curves to work, we need both class values in the set - if len(set(target_values)) == 2: + if len(set(target_values)) == 2: # noqa:PLR2004 roc_auc = roc_auc_score(target_values, probabilities) self._writer.add_scalar(f"{pass_name} ROC AUC", roc_auc, epoch_number) @@ -173,20 +173,21 @@ def is_compatible_with( target_data_shape: int | None = None, ) -> bool: """For regression, target data is needed and output data must be a list of two-dimensional values.""" - return output_data_shape == 2 and target_data_shape == 1 + return output_data_shape == 2 and target_data_shape == 1 # noqa:PLR2004 class ScatterPlotExporter(OutputExporter): - def __init__(self, directory_path: str, epoch_interval: int = 1): - """An output exporter that can make scatter plots, containing every single data point. + """An output exporter that can make scatter plots, containing every single data point. + + On the X-axis: targets values + On the Y-axis: output values - On the X-axis: targets values - On the Y-axis: output values + Args: + directory_path (str): Where to store the plots. + epoch_interval (int, optional): How often to make a plot, 5 means: every 5 epochs. Defaults to 1. + """ - Args: - directory_path (str): Where to store the plots. - epoch_interval (int, optional): How often to make a plot, 5 means: every 5 epochs. Defaults to 1. 
- """ + def __init__(self, directory_path: str, epoch_interval: int = 1): super().__init__(directory_path) self._epoch_interval = epoch_interval @@ -194,15 +195,15 @@ def __enter__(self): self._plot_data = {} return self - def __exit__(self, exception_type, exception, traceback): + def __exit__(self, exception_type, exception, traceback): # noqa: ANN001 self._plot_data.clear() - def get_filename(self, epoch_number): + def get_filename(self, epoch_number: int) -> str: """Returns the filename for the table.""" return os.path.join(self._directory_path, f"scatter-{epoch_number}.png") @staticmethod - def _get_color(pass_name): + def _get_color(pass_name: str) -> str: pass_name = pass_name.lower().strip() if pass_name in ("train", "training"): return "blue" @@ -217,7 +218,7 @@ def _plot( epoch_number: int, data: dict[str, tuple[list[float], list[float]]], png_path: str, - ): + ) -> None: plt.title(f"Epoch {epoch_number}") for pass_name, (truth_values, prediction_values) in data.items(): @@ -239,11 +240,11 @@ def process( self, pass_name: str, epoch_number: int, - entry_names: list[str], # noqa: ARG002 (unused argument) + entry_names: list[str], # noqa: ARG002 output_values: list, target_values: list, - loss: float, # noqa: ARG002 (unused argument) - ): + loss: float, # noqa: ARG002 + ) -> None: """Make the plot, if the epoch matches with the interval.""" if epoch_number % self._epoch_interval == 0: if epoch_number not in self._plot_data: @@ -296,7 +297,7 @@ def __enter__(self): return self - def __exit__(self, exception_type, exception, traceback): + def __exit__(self, exception_type, exception, traceback): # noqa: ANN001 if self.phase is not None: if self.phase == "validation": self.phase = "training" @@ -315,7 +316,7 @@ def process( output_values: list, target_values: list, loss: float, - ): + ) -> None: self.phase = pass_name pass_name = [pass_name] * len(output_values) loss = [loss] * len(output_values) diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py index a30832775..e131421bb 100644 --- a/deeprank2/utils/graph.py +++ b/deeprank2/utils/graph.py @@ -26,11 +26,13 @@ class Edge: + """Graph edge.""" + def __init__(self, id_: Contact): self.id = id_ self.features = {} - def add_feature(self, feature_name: str, feature_function: Callable[[Contact], float]): + def add_feature(self, feature_name: str, feature_function: Callable[[Contact], float]) -> None: feature_value = feature_function(self.id) self.features[feature_name] = feature_value @@ -49,6 +51,8 @@ def has_nan(self) -> bool: class Node: + """Graph node.""" + def __init__(self, id_: Atom | Residue): if isinstance(id_, Atom): self._type = "atom" @@ -61,7 +65,7 @@ def __init__(self, id_: Atom | Residue): self.features = {} @property - def type(self): + def type(self) -> str: return self._type def has_nan(self) -> bool: @@ -72,12 +76,13 @@ def add_feature( self, feature_name: str, feature_function: Callable[[Atom | Residue], NDArray], - ): + ) -> None: feature_value = feature_function(self.id) if len(feature_value.shape) != 1: shape_s = "x".join(feature_value.shape) - raise ValueError(f"Expected a 1-dimensional array for feature {feature_name}, but got {shape_s}") + msg = f"Expected a 1-dimensional array for feature {feature_name}, but got {shape_s}" + raise ValueError(msg) self.features[feature_name] = feature_value @@ -87,6 +92,8 @@ def position(self) -> np.array: class Graph: + """Graph.""" + def __init__(self, id_: str): self.id = id_ @@ -99,13 +106,13 @@ def __init__(self, id_: str): # the center only needs to be set when 
this graph should be mapped to a grid. self.center = np.array((0.0, 0.0, 0.0)) - def add_node(self, node: Node): + def add_node(self, node: Node) -> None: self._nodes[node.id] = node def get_node(self, id_: Atom | Residue) -> Node: return self._nodes[id_] - def add_edge(self, edge: Edge): + def add_edge(self, edge: Edge) -> None: self._edges[edge.id] = edge def get_edge(self, id_: Contact) -> Edge: @@ -134,7 +141,7 @@ def _map_point_features( points: list[NDArray], values: list[float | NDArray], augmentation: Augmentation | None = None, - ): + ) -> None: points = np.stack(points, axis=0) if augmentation is not None: @@ -156,7 +163,7 @@ def map_to_grid( grid: Grid, method: MapMethod, augmentation: Augmentation | None = None, - ): + ) -> None: # order edge features by xyz point points = [] feature_values = {} @@ -164,7 +171,7 @@ def map_to_grid( points += [edge.position1, edge.position2] for feature_name, feature_value in edge.features.items(): - feature_values[feature_name] = feature_values.get(feature_name, []) + [ # noqa: RUF005 (collection-literal-concatenation) + feature_values[feature_name] = feature_values.get(feature_name, []) + [ # noqa: RUF005 feature_value, feature_value, ] @@ -187,7 +194,7 @@ def map_to_grid( points.append(node.position) for feature_name, feature_value in node.features.items(): - feature_values[feature_name] = feature_values.get(feature_name, []) + [feature_value] # noqa: RUF005 (collection-literal-concatenation) + feature_values[feature_name] = feature_values.get(feature_name, []) + [feature_value] # noqa: RUF005 # map node features to grid for feature_name, values in feature_values.items(): @@ -200,7 +207,7 @@ def map_to_grid( augmentation, ) - def write_to_hdf5(self, hdf5_path: str): + def write_to_hdf5(self, hdf5_path: str) -> None: """Write a featured graph to an hdf5 file, according to deeprank standards.""" with h5py.File(hdf5_path, "a") as hdf5_file: # create groups to hold data @@ -344,7 +351,8 @@ def build_graph( atoms_residues = np.array(atoms_residues) NodeContact = ResidueContact else: - raise TypeError("All nodes in the graph must be of the same type.") + msg = "All nodes in the graph must be of the same type." + raise TypeError(msg) positions = np.empty((len(atoms), 3)) for atom_index, atom in enumerate(atoms): diff --git a/deeprank2/utils/grid.py b/deeprank2/utils/grid.py index 2c8f8a55f..70e7bda6f 100644 --- a/deeprank2/utils/grid.py +++ b/deeprank2/utils/grid.py @@ -57,8 +57,9 @@ def __init__( points_counts: list[int], sizes: list[float], ): - if len(points_counts) != 3 or len(sizes) != 3: - raise ValueError("Incorrect grid dimensions.") + if len(points_counts) != 3 or len(sizes) != 3: # noqa:PLR2004 + msg = "Incorrect grid dimensions." + raise ValueError(msg) self._points_counts = points_counts self._sizes = sizes @@ -93,7 +94,7 @@ def __init__(self, id_: str, center: list[float], settings: GridSettings): self._set_mesh(self._center, settings) self._features = {} - def _set_mesh(self, center: NDArray, settings: GridSettings): + def _set_mesh(self, center: NDArray, settings: GridSettings) -> None: """Builds the grid points.""" half_size_x = settings.sizes[0] / 2 half_size_y = settings.sizes[1] / 2 @@ -145,7 +146,7 @@ def zgrid(self) -> NDArray: def features(self) -> dict[str, NDArray]: return self._features - def add_feature_values(self, feature_name: str, data: NDArray): + def add_feature_values(self, feature_name: str, data: NDArray) -> None: """Makes sure feature values per grid point get stored. 
This method may be called repeatedly to add on to existing grid point values. @@ -269,7 +270,7 @@ def map_feature( feature_name: str, feature_value: NDArray | float, method: MapMethod, - ): + ) -> None: """Maps point feature data at a given position to the grid, using the given method. The feature_value should either be a single number or a one-dimensional array. @@ -304,7 +305,7 @@ def map_feature( # set to grid self.add_feature_values(index_name, grid_data) - def to_hdf5(self, hdf5_path: str): + def to_hdf5(self, hdf5_path: str) -> None: """Write the grid data to hdf5, according to deeprank standards.""" with h5py.File(hdf5_path, "a") as hdf5_file: # create a group to hold everything diff --git a/deeprank2/utils/parsing/__init__.py b/deeprank2/utils/parsing/__init__.py index 8211b9e80..1d253eee0 100644 --- a/deeprank2/utils/parsing/__init__.py +++ b/deeprank2/utils/parsing/__init__.py @@ -13,7 +13,7 @@ _forcefield_directory_path = os.path.realpath(os.path.join(os.path.dirname(__file__), "../../domain/forcefield")) -class AtomicForcefield: +class AtomicForcefield: # noqa: D101 def __init__(self): top_path = os.path.join(_forcefield_directory_path, "protein-allhdg5-5_new.top") with open(top_path, encoding="utf-8") as f: @@ -31,7 +31,7 @@ def __init__(self): with open(param_path, encoding="utf-8") as f: self._vanderwaals_parameters = ParamParser.parse(f) - def _find_matching_residue_class(self, residue: Residue): + def _find_matching_residue_class(self, residue: Residue) -> str | None: for criterium in self._residue_class_criteria: if criterium.matches( residue.amino_acid.three_letter_code, @@ -41,7 +41,7 @@ def _find_matching_residue_class(self, residue: Residue): return None - def get_vanderwaals_parameters(self, atom: Atom): + def get_vanderwaals_parameters(self, atom: Atom) -> VanderwaalsParam: atom_name = atom.name if atom.residue.amino_acid is None: @@ -69,7 +69,7 @@ def get_vanderwaals_parameters(self, atom: Atom): return VanderwaalsParam(0.0, 0.0, 0.0, 0.0) return self._vanderwaals_parameters[type_] - def get_charge(self, atom: Atom): + def get_charge(self, atom: Atom) -> float: """Get the charge of a given `Atom`. 
Args: diff --git a/deeprank2/utils/parsing/patch.py b/deeprank2/utils/parsing/patch.py index 728ff94db..4c0c27163 100644 --- a/deeprank2/utils/parsing/patch.py +++ b/deeprank2/utils/parsing/patch.py @@ -2,6 +2,8 @@ from enum import Enum from typing import Any +# ruff: noqa: D101 + class PatchActionType(Enum): MODIFY = 1 @@ -20,10 +22,10 @@ def __init__(self, type_: str, selection: PatchSelection, kwargs: dict[str, Any] self.selection = selection self.kwargs = kwargs - def __contains__(self, key): + def __contains__(self, key: str): return key in self.kwargs - def __getitem__(self, key): + def __getitem__(self, key: str): return self.kwargs[key] @@ -33,15 +35,16 @@ class PatchParser: ACTION_PATTERN = re.compile(r"^([A-Z]{3,4})\s+([A-Z]+)\s+ATOM\s+([A-Z0-9]{1,3})\s+(.*)$") @staticmethod - def _parse_action_type(s): + def _parse_action_type(s: str) -> PatchActionType: for type_ in PatchActionType: if type_.name == s: return type_ - raise ValueError(f"Unmatched residue action: {s!r}") + msg = f"Unmatched residue action: {s!r}" + raise ValueError(msg) @staticmethod - def parse(file_): + def parse(file_: str) -> list[PatchAction]: result = [] for line in file_: if line.startswith(("#", "!")) or len(line.strip()) == 0: @@ -49,7 +52,8 @@ def parse(file_): m = PatchParser.ACTION_PATTERN.match(line) if not m: - raise ValueError(f"Unmatched patch action: {line!r}") + msg = f"Unmatched patch action: {line!r}" + raise ValueError(msg) residue_type = m.group(1) action_type = PatchParser._parse_action_type(m.group(2)) diff --git a/deeprank2/utils/parsing/residue.py b/deeprank2/utils/parsing/residue.py index b37b734fe..49d7ddf7a 100644 --- a/deeprank2/utils/parsing/residue.py +++ b/deeprank2/utils/parsing/residue.py @@ -1,7 +1,7 @@ import re -class ResidueClassCriterium: +class ResidueClassCriterium: # noqa: D101 def __init__( self, class_name: str, @@ -33,24 +33,25 @@ def matches(self, amino_acid_name: str, atom_names: list[str]) -> bool: return True -class ResidueClassParser: +class ResidueClassParser: # noqa: D101 _RESIDUE_CLASS_PATTERN = re.compile(r"([A-Z]{3,4}) *\: *name *\= *(all|[A-Z]{3})") _RESIDUE_ATOMS_PATTERN = re.compile(r"(present|absent)\(([A-Z0-9\, ]+)\)") @staticmethod - def parse(file_): + def parse(file_: str) -> list[ResidueClassCriterium]: result = [] for line in file_: match = ResidueClassParser._RESIDUE_CLASS_PATTERN.match(line) if not match: - raise ValueError(f"Unparsable line: '{line}'") + msg = f"Unparsable line: '{line}'" + raise ValueError(msg) class_name = match.group(1) amino_acid_names = ResidueClassParser._parse_amino_acids(match.group(2)) present_atom_names = [] absent_atom_names = [] - for match in ResidueClassParser._RESIDUE_ATOMS_PATTERN.finditer(line[match.end() :]): # noqa: B020 (loop-variable-overrides-iterator) + for match in ResidueClassParser._RESIDUE_ATOMS_PATTERN.finditer(line[match.end() :]): # noqa: B020 atom_names = [name.strip() for name in match.group(2).split(",")] if match.group(1) == "present": present_atom_names.extend(atom_names) @@ -62,7 +63,7 @@ def parse(file_): return result @staticmethod - def _parse_amino_acids(string: str): + def _parse_amino_acids(string: str) -> str | list[str]: if string.strip() == "all": return string.strip() return [name.strip() for name in string.split(",")] diff --git a/deeprank2/utils/parsing/top.py b/deeprank2/utils/parsing/top.py index 84b6f8fa1..349937535 100644 --- a/deeprank2/utils/parsing/top.py +++ b/deeprank2/utils/parsing/top.py @@ -5,7 +5,7 @@ logging.getLogger(__name__) -class TopRowObject: +class 
TopRowObject: # noqa: D101 def __init__( self, residue_name: str, @@ -16,23 +16,24 @@ def __init__( self.atom_name = atom_name self.kwargs = kwargs - def __getitem__(self, key): + def __getitem__(self, key: str): return self.kwargs[key] -class TopParser: +class TopParser: # noqa: D101 _VAR_PATTERN = re.compile(r"([^\s]+)\s*=\s*([^\s\(\)]+|\(.*\))") _LINE_PATTERN = re.compile(r"^([A-Z0-9]{3})\s+atom\s+([A-Z0-9]{1,4})\s+(.+)\s+end\s*(\s+\!\s+[ _A-Za-z0-9]+)?$") _NUMBER_PATTERN = re.compile(r"\-?[0-9]+(\.[0-9]+)?") @staticmethod - def parse(file_): + def parse(file_: str) -> list[TopRowObject]: result = [] for line in file_: # parse the line m = TopParser._LINE_PATTERN.match(line) if not m: - raise ValueError(f"Unmatched top line: {line}") + msg = f"Unmatched top line: {line}" + raise ValueError(msg) residue_name = m.group(1).upper() atom_name = m.group(2).upper() @@ -46,7 +47,7 @@ def parse(file_): return result @staticmethod - def _parse_value(s): + def _parse_value(s: str) -> float | str: # remove parentheses if s[0] == "(" and s[-1] == ")": return TopParser._parse_value(s[1:-1]) diff --git a/deeprank2/utils/parsing/vdwparam.py b/deeprank2/utils/parsing/vdwparam.py index 73a594436..07e77b75c 100644 --- a/deeprank2/utils/parsing/vdwparam.py +++ b/deeprank2/utils/parsing/vdwparam.py @@ -1,4 +1,4 @@ -class VanderwaalsParam: +class VanderwaalsParam: # noqa: D101 def __init__( self, epsilon_main: float, @@ -15,9 +15,9 @@ def __str__(self) -> str: return f"{self.epsilon_main}, {self.sigma_main}, {self.epsilon_14}, {self.sigma_14}" -class ParamParser: +class ParamParser: # noqa: D101 @staticmethod - def parse(file_): + def parse(file_: str) -> dict[str, VanderwaalsParam]: result = {} for line in file_: if line.startswith("#"): @@ -42,6 +42,7 @@ def parse(file_): elif len(line.strip()) == 0: continue else: - raise ValueError(f"Unparsable param line: {line}") + msg = f"Unparsable param line: {line}" + raise ValueError(msg) return result diff --git a/deeprank2/utils/pssmdata.py b/deeprank2/utils/pssmdata.py index 49134d1ac..15bc89fc7 100644 --- a/deeprank2/utils/pssmdata.py +++ b/deeprank2/utils/pssmdata.py @@ -1,4 +1,7 @@ +from typing_extensions import Self + from deeprank2.molstruct.aminoacid import AminoAcid +from deeprank2.molstruct.residue import Residue class PssmRow: @@ -33,12 +36,12 @@ def __init__(self, rows: list[PssmRow] | None = None): else: self._rows = rows - def __contains__(self, residue) -> bool: + def __contains__(self, residue: Residue) -> bool: return residue in self._rows - def __getitem__(self, residue) -> PssmRow: + def __getitem__(self, residue: Residue) -> PssmRow: return self._rows[residue] - def update(self, other): + def update(self, other: Self) -> None: """Can be used to merge two non-overlapping scoring tables.""" - self._rows.update(other._rows) # noqa: SLF001 (private-member-access) + self._rows.update(other._rows) # noqa: SLF001 diff --git a/docs/conf.py b/docs/conf.py index 35f734134..5ed177fa5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,14 +13,15 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import configparser # noqa: F401 (unused-import) +import configparser + # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
import os import sys -import toml +import toml # pyright: ignore[reportMissingModuleSource] autodoc_mock_imports = [ "numpy", diff --git a/pyproject.toml b/pyproject.toml index 27e309a18..6482eb385 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,113 +85,71 @@ exclude = ["tests*", "*tests.*", "*tests"] "*" = ["*.xlsx", "*.param", "*.top", "*residue-classes"] [tool.ruff] -# Exclude a variety of commonly ignored directories. -extend-exclude = ["docs", "reduce"] line-length = 159 -select = [ - "F", # Pyflakes - "E", # pycodestyle (error) - "W", # pycodestyle (warning) - "I", # isort - "D", # pydocstyle - "UP", # pyupgrade - "SIM", # simplify - "C4", # flake8-comprehensions - "S", # flake8-bandit - "PGH", # pygrep-hooks - "BLE", # blind-except - "FBT003", # boolean-positional-value-in-call - "B", # flake8-bugbear - "Q", # flake8-quotes - "PLR", # pylint refactoring - "ARG", # flake8-unused-arguments - "SLF001", # Private member accessed - "PIE", # flake8-pie - "RET", # flaske8-return - "PT", # pytest - "TID", # imports - "TCH", # imports - "PD", # pandas - "NPY", # numpy - "PL", # pylint - "RUF", # ruff rtecommendations - "PERF", # performance - "TRY", # try blocks - "ERA", # commented out code - # other linting conventions - "FLY", - "AIR", - "YTT", - "ASYNC", - "A", - "DTZ", - "DJ", - "FA", - "ISC", - "ICN", - "G", - "INP", - "PYI", - "Q", - "RSE102", - "SLOT", - "INT", - # The following are unrealistic for this code base - # "PTH" # flake8-use-pathlib - # "ANN", # annotations - # "N", # naming conventions - # "C901", # mccabe complexity -] +select = ["ALL"] ignore = [ + # Unrealistic for this code base + "PTH", # flake8-use-pathlib + "N", # naming conventions "PLR0912", # Too many branches, - "PLR0913", #Too many arguments in function definition - "B028", # No explicit `stacklevel` keyword argument found in - "PLR2004", # Magic value used in comparison - "S105", # Possible hardcoded password - "S311", # insecure random generators - "PT011", # pytest-raises-too-broad - "SIM108", # Use ternary operator - "TRY003", # Long error messages - # Missing docstrings Documentation + "PLR0913", # Too many arguments in function definition + "D102", # Missing docstring in public method + # Unwanted + "FBT", # Using boolean arguments + "ANN101", # Missing type annotation for `self` in method + "ANN102", # Missing type annotation for `cls` in classmethod + "ANN204", # Missing return type annotation for special (dunder) method + "B028", # No explicit `stacklevel` keyword argument found in warning + "S105", # Possible hardcoded password + "S311", # insecure random generators + "PT011", # pytest-raises-too-broad + "SIM108", # Use ternary operator + # Unwanted docstrings "D100", # Missing module docstring - "D101", # Missing docstring in public class - "D102", # Missing docstring in public method - "D103", # Missing docstring in public function "D104", # Missing public package docstring "D105", # Missing docstring in magic method "D107", # Missing docstring in `__init__` - # Rules irrelevant to the Google style + # Docstring rules irrelevant to the Google style "D203", # 1 blank line required before class docstring - "D204", + "D204", # 1 blank line required after class docstring "D212", # Multi-line docstring summary should start at the first line "D213", # Multi-line docstring summary should start at the second line - "D215", - "D400", - "D401", + "D215", # Section underline is over-indented + "D400", # First line should end with a period (clashes with D415:First line should end with a period, 
question mark, or exclamation point) + "D401", # First line of docstring should be in imperative mood "D404", # First word of the docstring should not be This - "D406", - "D407", - "D408", - "D409", - "D413", + "D406", # Section name should end with a newline + "D407", # Missing dashed underline after section + "D408", # Section underline should be in the line following the section's name + "D409", # Section underline should match the length of its name + "D413", # Missing blank line after last section ] -# Allow autofix for all enabled rules (when `--fix`) is provided. +# Allow autofix for all enabled rules. fixable = ["ALL"] -unfixable = [ - "F401", -] # unused imports (it's annoying if they automatically disappear while editing code +unfixable = ["F401"] # unused imports (should not disappear while editing) [tool.ruff.lint.per-file-ignores] -"tests/*" = ["S101"] +"tests/*" = [ + "S101", # Use of `assert` detected + "PLR2004", # Magic value used in comparison + "D101", # Missing class docstring + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function +] +"docs/*" = ["ALL"] +"tests/perf/*" = ["T201"] # Use of print statements [tool.ruff.lint] extend-safe-fixes = [ "D415", # First line should end with a period, question mark, or exclamation point "D300", # Use triple double quotes `"""` "D200", # One-line docstring should fit on one line - "TCH", # type checking only imports - "ISC001", + "TCH", # Format type checking only imports + "ISC001", # Implicitly concatenated strings on a single line + "EM", # Exception message variables + "RUF013", # Implicit Optional + "B006", # Mutable default argument ] [tool.ruff.isort] diff --git a/tests/domain/test_aminoacidlist.py b/tests/domain/test_aminoacidlist.py index f734edd36..736ddcb36 100644 --- a/tests/domain/test_aminoacidlist.py +++ b/tests/domain/test_aminoacidlist.py @@ -11,7 +11,7 @@ ] -def test_all_different_onehot(): +def test_all_different_onehot() -> None: for aa1, aa2 in zip(amino_acids, amino_acids, strict=True): if aa1 == aa2: continue @@ -22,4 +22,5 @@ def test_all_different_onehot(): if (aa1 in EXCEPTIONS[0] and aa2 in EXCEPTIONS[0]) or (aa1 in EXCEPTIONS[1] and aa2 in EXCEPTIONS[1]): assert np.all(aa1.onehot == aa2.onehot) else: - raise AssertionError(f"One-hot index {aa1.index} is occupied by both {aa1} and {aa2}") from e + msg = f"One-hot index {aa1.index} is occupied by both {aa1} and {aa2}" + raise AssertionError(msg) from e diff --git a/tests/domain/test_forcefield.py b/tests/domain/test_forcefield.py index ef4196c3a..4db2c6db7 100644 --- a/tests/domain/test_forcefield.py +++ b/tests/domain/test_forcefield.py @@ -5,12 +5,12 @@ from deeprank2.utils.parsing import atomic_forcefield -def test_atomic_forcefield(): +def test_atomic_forcefield() -> None: pdb = pdb2sql("tests/data/pdb/101M/101M.pdb") try: structure = get_structure(pdb, "101M") finally: - pdb._close() # noqa: SLF001 (private member accessed) + pdb._close() # noqa: SLF001 # The arginine C-zeta should get a positive charge arg = next(r for r in structure.get_chain("A").residues if r.amino_acid == arginine) diff --git a/tests/features/__init__.py b/tests/features/__init__.py index 8521d2b35..541e5c777 100644 --- a/tests/features/__init__.py +++ b/tests/features/__init__.py @@ -16,10 +16,11 @@ def _get_residue(chain: Chain, number: int) -> Residue: for residue in chain.residues: if residue.number == number: return residue - raise ValueError(f"Not found: {number}") + msg = f"Not found: {number}" + raise ValueError(msg) -def 
build_testgraph( +def build_testgraph( # noqa: C901 pdb_path: str, detail: Literal["atom", "residue"], influence_radius: float, @@ -54,7 +55,7 @@ def build_testgraph( try: structure: PDBStructure = get_structure(pdb, Path(pdb_path).stem) finally: - pdb._close() # noqa: SLF001 (private member accessed) + pdb._close() # noqa: SLF001 if not central_res: nodes = set() @@ -73,7 +74,8 @@ def build_testgraph( for atom in residue2.atoms: nodes.add(atom) else: - raise TypeError('detail must be "atom" or "residue"') + msg = 'detail must be "atom" or "residue"' + raise TypeError(msg) return Graph.build_graph(list(nodes), structure.id, max_edge_length), None @@ -99,4 +101,5 @@ def build_testgraph( if detail == "atom": atoms = {atom for residue in surrounding_residues for atom in residue.atoms} return Graph.build_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant) - raise TypeError('detail must be "atom" or "residue"') + msg = 'detail must be "atom" or "residue"' + raise TypeError(msg) diff --git a/tests/features/test_components.py b/tests/features/test_components.py index f00b91c41..88f47dbd1 100644 --- a/tests/features/test_components.py +++ b/tests/features/test_components.py @@ -7,7 +7,7 @@ from . import build_testgraph -def test_atom_features(): +def test_atom_features() -> None: pdb_path = "tests/data/pdb/101M/101M.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, @@ -21,7 +21,7 @@ def test_atom_features(): assert not any(np.isnan(node.features[Nfeat.PDBOCCUPANCY]) for node in graph.nodes) -def test_aminoacid_features(): +def test_aminoacid_features() -> None: pdb_path = "tests/data/pdb/101M/101M.pdb" graph, variant = build_testgraph( pdb_path=pdb_path, diff --git a/tests/features/test_conservation.py b/tests/features/test_conservation.py index a15a43bfd..03ff82da9 100644 --- a/tests/features/test_conservation.py +++ b/tests/features/test_conservation.py @@ -8,7 +8,7 @@ from . 
import build_testgraph -def test_conservation_residue(): +def test_conservation_residue() -> None: pdb_path = "tests/data/pdb/101M/101M.pdb" graph, variant = build_testgraph( pdb_path=pdb_path, @@ -29,7 +29,7 @@ def test_conservation_residue(): assert np.any([node.features[feature_name] != 0.0 for node in graph.nodes]), f"all 0s found for {feature_name}" -def test_conservation_atom(): +def test_conservation_atom() -> None: pdb_path = "tests/data/pdb/101M/101M.pdb" graph, variant = build_testgraph( pdb_path=pdb_path, @@ -50,7 +50,7 @@ def test_conservation_atom(): assert np.any([node.features[feature_name] != 0.0 for node in graph.nodes]), f"all 0s found for {feature_name}" -def test_no_pssm_file_error(): +def test_no_pssm_file_error() -> None: pdb_path = "tests/data/pdb/1CRN/1CRN.pdb" graph, variant = build_testgraph( pdb_path=pdb_path, diff --git a/tests/features/test_contact.py b/tests/features/test_contact.py index ab1ef06f5..d9e9f001a 100644 --- a/tests/features/test_contact.py +++ b/tests/features/test_contact.py @@ -18,10 +18,11 @@ def _get_atom(chain: Chain, residue_number: int, atom_name: str) -> Atom: for atom in residue.atoms: if atom.name == atom_name: return atom - raise ValueError(f"Not found: chain {chain.id} residue {residue_number} atom {atom_name}") + msg = f"Not found: chain {chain.id} residue {residue_number} atom {atom_name}" + raise ValueError(msg) -def _wrap_in_graph(edge: Edge): +def _wrap_in_graph(edge: Edge) -> Graph: g = Graph(uuid4().hex) g.add_edge(edge) return g @@ -42,7 +43,7 @@ def _get_contact( try: structure = get_structure(pdb, pdb_id) finally: - pdb._close() # noqa: SLF001 (private member accessed) + pdb._close() # noqa: SLF001 if not chains: chains = [structure.chains[0], structure.chains[0]] @@ -71,7 +72,7 @@ def _get_contact( return edge_obj -def test_covalent_pair(): +def test_covalent_pair() -> None: """MET 0: N - CA, covalent pair (at 1.49 A distance). Should have 0 vanderwaals and electrostatic energies.""" edge_covalent = _get_contact("101M", 0, "N", 0, "CA") assert edge_covalent.features[Efeat.DISTANCE] < covalent_cutoff @@ -80,7 +81,7 @@ def test_covalent_pair(): assert edge_covalent.features[Efeat.COVALENT] == 1.0, "covalent pair not recognized as covalent" -def test_13_pair(): +def test_13_pair() -> None: """MET 0: N - CB, 1-3 pair (at 2.47 A distance). Should have 0 vanderwaals and electrostatic energies.""" edge_13 = _get_contact("101M", 0, "N", 0, "CB") assert edge_13.features[Efeat.DISTANCE] < cutoff_13 @@ -89,7 +90,7 @@ def test_13_pair(): assert edge_13.features[Efeat.COVALENT] == 0.0, "1-3 pair recognized as covalent" -def test_very_close_opposing_chains(): +def test_very_close_opposing_chains() -> None: """ChainA THR 118 O - ChainB ARG 30 NH1 (3.55 A). Should have non-zero energy despite close contact, because opposing chains.""" opposing_edge = _get_contact("1A0Z", 118, "O", 30, "NH1", chains=("A", "B")) assert opposing_edge.features[Efeat.DISTANCE] < cutoff_13 @@ -97,7 +98,7 @@ def test_very_close_opposing_chains(): assert opposing_edge.features[Efeat.VDW] != 0.0 -def test_14_pair(): +def test_14_pair() -> None: """MET 0: N - CG, 1-4 pair (at 4.12 A distance). 
Should have non-zero electrostatic energy and small non-zero vdw energy.""" edge_14 = _get_contact("101M", 0, "CA", 0, "SD") assert edge_14.features[Efeat.DISTANCE] > cutoff_13 @@ -108,7 +109,7 @@ def test_14_pair(): assert edge_14.features[Efeat.COVALENT] == 0.0, "1-4 pair recognized as covalent" -def test_14dist_opposing_chains(): +def test_14dist_opposing_chains() -> None: """ChainA PRO 114 CA - ChainB HIS 116 CD2 (3.62 A). Should have non-zero energy despite close contact, because opposing chains. @@ -122,20 +123,20 @@ def test_14dist_opposing_chains(): assert opposing_edge.features[Efeat.VDW] > 0.1, f"vdw: {opposing_edge.features[Efeat.VDW]}" -def test_vanderwaals_negative(): +def test_vanderwaals_negative() -> None: """MET 0 N - ASP 27 CB, very far (29.54 A). Should have negative vanderwaals energy.""" edge_far = _get_contact("101M", 0, "N", 27, "CB") assert edge_far.features[Efeat.VDW] < 0.0 -def test_vanderwaals_morenegative(): +def test_vanderwaals_morenegative() -> None: """MET 0 N - PHE 138 CG, intermediate distance (12.69 A). Should have more negative vanderwaals energy than the far interaction.""" edge_intermediate = _get_contact("101M", 0, "N", 138, "CG") edge_far = _get_contact("101M", 0, "N", 27, "CB") assert edge_intermediate.features[Efeat.VDW] < edge_far.features[Efeat.VDW] -def test_edge_distance(): +def test_edge_distance() -> None: """Check the edge distances.""" edge_close = _get_contact("101M", 0, "N", 0, "CA") edge_intermediate = _get_contact("101M", 0, "N", 138, "CG") @@ -145,13 +146,13 @@ def test_edge_distance(): assert edge_far.features[Efeat.DISTANCE] > edge_intermediate.features[Efeat.DISTANCE], "far distance < intermediate distance" -def test_attractive_electrostatic_close(): +def test_attractive_electrostatic_close() -> None: """ARG 139 CZ - GLU 136 OE2, very close (5.60 A). Should have attractive electrostatic energy.""" close_attracting_edge = _get_contact("101M", 139, "CZ", 136, "OE2") assert close_attracting_edge.features[Efeat.ELEC] < 0.0 -def test_attractive_electrostatic_far(): +def test_attractive_electrostatic_far() -> None: """ARG 139 CZ - ASP 20 OD2, far (24.26 A). Should have attractive more electrostatic energy than above.""" far_attracting_edge = _get_contact("101M", 139, "CZ", 20, "OD2") close_attracting_edge = _get_contact("101M", 139, "CZ", 136, "OE2") @@ -159,13 +160,13 @@ def test_attractive_electrostatic_far(): assert far_attracting_edge.features[Efeat.ELEC] > close_attracting_edge.features[Efeat.ELEC], "far electrostatic <= close electrostatic" -def test_repulsive_electrostatic(): +def test_repulsive_electrostatic() -> None: """GLU 109 OE2 - GLU 105 OE1 (9.64 A). Should have repulsive electrostatic energy.""" opposing_edge = _get_contact("101M", 109, "OE2", 105, "OE1") assert opposing_edge.features[Efeat.ELEC] > 0.0 -def test_residue_contact(): +def test_residue_contact() -> None: """Check that we can calculate features for residue contacts.""" res_edge = _get_contact("101M", 0, "", 1, "", residue_level=True) assert res_edge.features[Efeat.DISTANCE] > 0.0, "distance <= 0" diff --git a/tests/features/test_exposure.py b/tests/features/test_exposure.py index 39d7129ca..1389f9fa8 100644 --- a/tests/features/test_exposure.py +++ b/tests/features/test_exposure.py @@ -7,13 +7,13 @@ from . 
import build_testgraph -def _run_assertions(graph: Graph): +def _run_assertions(graph: Graph) -> None: assert np.any(node.features[Nfeat.HSE] != 0.0 for node in graph.nodes), "hse" assert np.any(node.features[Nfeat.RESDEPTH] != 0.0 for node in graph.nodes), "resdepth" -def test_exposure_residue(): +def test_exposure_residue() -> None: pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, @@ -25,7 +25,7 @@ def test_exposure_residue(): _run_assertions(graph) -def test_exposure_atom(): +def test_exposure_atom() -> None: pdb_path = "tests/data/pdb/1ak4/1ak4.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, diff --git a/tests/features/test_irc.py b/tests/features/test_irc.py index 3505d52b4..e5cb26fcc 100644 --- a/tests/features/test_irc.py +++ b/tests/features/test_irc.py @@ -7,16 +7,16 @@ from . import build_testgraph -def _run_assertions(graph: Graph): +def _run_assertions(graph: Graph) -> None: assert not np.any([np.isnan(node.features[Nfeat.IRCTOTAL]) for node in graph.nodes]), "nan found" assert np.any([node.features[Nfeat.IRCTOTAL] > 0 for node in graph.nodes]), "no contacts" assert np.all( - [node.features[Nfeat.IRCTOTAL] == sum(node.features[IRCtype] for IRCtype in Nfeat.IRC_FEATURES[:-1]) for node in graph.nodes] + [node.features[Nfeat.IRCTOTAL] == sum(node.features[IRCtype] for IRCtype in Nfeat.IRC_FEATURES[:-1]) for node in graph.nodes], ), "incorrect total" -def test_irc_residue(): +def test_irc_residue() -> None: pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, @@ -28,7 +28,7 @@ def test_irc_residue(): _run_assertions(graph) -def test_irc_atom(): +def test_irc_atom() -> None: pdb_path = "tests/data/pdb/1A0Z/1A0Z.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, diff --git a/tests/features/test_secondary_structure.py b/tests/features/test_secondary_structure.py index 7d2ff1ff0..adc11a42e 100644 --- a/tests/features/test_secondary_structure.py +++ b/tests/features/test_secondary_structure.py @@ -12,7 +12,7 @@ from . 
import build_testgraph -def test_secondary_structure_residue(): +def test_secondary_structure_residue() -> None: test_case = "9api" # properly formatted pdb file pdb_path = f"tests/data/pdb/{test_case}/{test_case}.pdb" graph, _ = build_testgraph( @@ -27,7 +27,6 @@ def test_secondary_structure_residue(): # Create a list of node information (residue number, chain ID, and secondary structure features) node_info_list = [[node.id.number, node.id.chain.id, node.features[Nfeat.SECSTRUCT]] for node in graph.nodes] - print(node_info_list) # Check that all nodes have exactly 1 secondary structure type assert np.all([np.sum(node.features[Nfeat.SECSTRUCT]) == 1.0 for node in graph.nodes]), "one hot encoding error" @@ -37,26 +36,26 @@ def test_secondary_structure_residue(): (267, "A", " ", SecondarySctructure.COIL), (46, "A", "S", SecondarySctructure.COIL), (104, "A", "T", SecondarySctructure.COIL), - # (None, '', 'P', SecondarySctructure.COIL), # not found in test file # noqa: ERA001 (commented-out code) + # (None, '', 'P', SecondarySctructure.COIL), # not found in test file # noqa: ERA001 (194, "A", "B", SecondarySctructure.STRAND), (385, "B", "E", SecondarySctructure.STRAND), (235, "A", "G", SecondarySctructure.HELIX), (263, "A", "H", SecondarySctructure.HELIX), - # (0, '', 'I', SecondarySctructure.HELIX), # not found in test file # noqa: ERA001 (commented-out code) + # (0, '', 'I', SecondarySctructure.HELIX), # not found in test file # noqa: ERA001 ] for res in residues: node_list = [node_info for node_info in node_info_list if (node_info[0] == res[0] and node_info[1] == res[1])] assert len(node_list) > 0, f"no nodes detected in {res[1]} {res[0]}" assert np.all( - [np.array_equal(node_info[2], _classify_secstructure(res[2]).onehot) for node_info in node_list] + [np.array_equal(node_info[2], _classify_secstructure(res[2]).onehot) for node_info in node_list], ), f"Ground truth examples: res {res[1]} {res[0]} is not {(res[2])}." assert np.all( - [np.array_equal(node_info[2], res[3].onehot) for node_info in node_list] + [np.array_equal(node_info[2], res[3].onehot) for node_info in node_list], ), f"Ground truth examples: res {res[1]} {res[0]} is not {res[3]}." -def test_secondary_structure_atom(): +def test_secondary_structure_atom() -> None: test_case = "1ak4" # ATOM list pdb_path = f"tests/data/pdb/{test_case}/{test_case}.pdb" graph, _ = build_testgraph( @@ -92,4 +91,5 @@ def test_secondary_structure_atom(): elif dssp_code in ["G", "H", "I"]: assert np.array_equal(node[2], SecondarySctructure.HELIX.onehot), f"Full file test: res {node[1]}{node[0]} is not a HELIX" else: - raise ValueError(f"Unexpected secondary structure type found at {node[1]}{node[0]}") + msg = f"Unexpected secondary structure type found at {node[1]}{node[0]}" + raise ValueError(msg) diff --git a/tests/features/test_surfacearea.py b/tests/features/test_surfacearea.py index e4efc3912..c4219fc27 100644 --- a/tests/features/test_surfacearea.py +++ b/tests/features/test_surfacearea.py @@ -2,27 +2,39 @@ from deeprank2.domain import nodestorage as Nfeat from deeprank2.features.surfacearea import add_features +from deeprank2.utils.graph import Graph, Node from . 
import build_testgraph -def _find_residue_node(graph, chain_id, residue_number): +def _find_residue_node( + graph: Graph, + chain_id: str, + residue_number: int, +) -> Node: for node in graph.nodes: residue = node.id if residue.chain.id == chain_id and residue.number == residue_number: return node - raise ValueError(f"Not found: {chain_id} {residue_number}") + msg = f"Not found: {chain_id} {residue_number}" + raise ValueError(msg) -def _find_atom_node(graph, chain_id, residue_number, atom_name): +def _find_atom_node( + graph: Graph, + chain_id: str, + residue_number: int, + atom_name: str, +) -> Node: for node in graph.nodes: atom = node.id if atom.residue.chain.id == chain_id and atom.residue.number == residue_number and atom.name == atom_name: return node - raise ValueError(f"Not found: {chain_id} {residue_number} {atom_name}") + msg = f"Not found: {chain_id} {residue_number} {atom_name}" + raise ValueError(msg) -def test_bsa_residue(): +def test_bsa_residue() -> None: pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, @@ -37,7 +49,7 @@ def test_bsa_residue(): assert node.features[Nfeat.BSA] > 0.0 -def test_bsa_atom(): +def test_bsa_atom() -> None: pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, @@ -52,7 +64,7 @@ def test_bsa_atom(): assert node.features[Nfeat.BSA] > 0.0 -def test_sasa_residue(): +def test_sasa_residue() -> None: pdb_path = "tests/data/pdb/101M/101M.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, @@ -75,7 +87,7 @@ def test_sasa_residue(): assert buried_residue_node.features[Nfeat.SASA] < 25.0 -def test_sasa_atom(): +def test_sasa_atom() -> None: pdb_path = "tests/data/pdb/101M/101M.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, diff --git a/tests/molstruct/test_pair.py b/tests/molstruct/test_pair.py index 19996e42a..2756d9e39 100644 --- a/tests/molstruct/test_pair.py +++ b/tests/molstruct/test_pair.py @@ -1,7 +1,7 @@ from deeprank2.molstruct.pair import Pair -def test_order_independency(): +def test_order_independency() -> None: # These should be the same: pair1 = Pair(1, 2) pair2 = Pair(2, 1) @@ -15,7 +15,7 @@ def test_order_independency(): assert d[pair1] == 2 -def test_uniqueness(): +def test_uniqueness() -> None: # These should be different: pair1 = Pair(1, 2) pair2 = Pair(1, 3) diff --git a/tests/molstruct/test_structure.py b/tests/molstruct/test_structure.py index 77fe8d5ec..6074fc265 100644 --- a/tests/molstruct/test_structure.py +++ b/tests/molstruct/test_structure.py @@ -7,23 +7,23 @@ from deeprank2.utils.buildgraph import get_structure -def _get_structure(path) -> PDBStructure: +def _get_structure(path: str) -> PDBStructure: pdb = pdb2sql(path) try: structure = get_structure(pdb, "101M") finally: - pdb._close() # noqa: SLF001 (private member accessed) + pdb._close() # noqa: SLF001 assert structure is not None return structure -def test_serialization_pickle(): +def test_serialization_pickle() -> None: structure = _get_structure("tests/data/pdb/101M/101M.pdb") s = pickle.dumps(structure) - loaded_structure = pickle.loads(s) # noqa: S301 (suspicious-pickle-usage) + loaded_structure = pickle.loads(s) # noqa: S301 assert loaded_structure == structure assert loaded_structure.get_chain("A") == structure.get_chain("A") @@ -32,7 +32,7 @@ def test_serialization_pickle(): assert loaded_structure.get_chain("A").get_residue(0).atoms[0] == structure.get_chain("A").get_residue(0).atoms[0] -def test_serialization_fork(): +def test_serialization_fork() -> None: structure = 
_get_structure("tests/data/pdb/101M/101M.pdb") s = _ForkingPickler.dumps(structure) diff --git a/tests/perf/ppi_perf.py b/tests/perf/ppi_perf.py index a184cc606..f1e48e356 100644 --- a/tests/perf/ppi_perf.py +++ b/tests/perf/ppi_perf.py @@ -28,8 +28,8 @@ sizes=[1.0, 1.0, 1.0], ) grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids -# grid_settings = None # noqa: ERA001 (commented out code) -# grid_map_method = None # noqa: ERA001 (commented out code) +# grid_settings = None # noqa: ERA001 +# grid_map_method = None # noqa: ERA001 feature_modules = [components, contact, exposure, irc, secondary_structure, surfacearea] cpu_count = 1 #################################################### @@ -41,7 +41,7 @@ os.makedirs(os.path.join(processed_data_path, "atomic")) -def get_pdb_files_and_target_data(data_path): +def get_pdb_files_and_target_data(data_path: str) -> (list[str], list): csv_data = pd.read_csv(os.path.join(data_path, "BA_values.csv")) pdb_files = glob.glob(os.path.join(data_path, "pdb", "*.pdb")) pdb_files.sort() @@ -69,7 +69,7 @@ def get_pdb_files_and_target_data(data_path): "binary": int(float(bas[i]) <= 500), # binary target value "BA": bas[i], # continuous target value }, - ) + ), ) start = time.perf_counter() diff --git a/tests/perf/srv_perf.py b/tests/perf/srv_perf.py index 6e358a215..9de2d8268 100644 --- a/tests/perf/srv_perf.py +++ b/tests/perf/srv_perf.py @@ -74,8 +74,8 @@ sizes=[1.0, 1.0, 1.0], ) grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids -# grid_settings = None # noqa: ERA001 (commented out code) -# grid_map_method = None # noqa: ERA001 (commented out code) +# grid_settings = None # noqa: ERA001 +# grid_map_method = None # noqa: ERA001 feature_modules = [components, contact, exposure, irc, surfacearea, secondary_structure] cpu_count = 1 #################################################### @@ -87,7 +87,7 @@ os.makedirs(os.path.join(processed_data_path, "atomic")) -def get_pdb_files_and_target_data(data_path): +def get_pdb_files_and_target_data(data_path: str) -> (list[str], list, list, list, list): csv_data = pd.read_csv(os.path.join(data_path, "srv_target_values.csv")) # before running this script change .ent to .pdb pdb_files = glob.glob(os.path.join(data_path, "pdb", "*.pdb")) @@ -129,7 +129,7 @@ def get_pdb_files_and_target_data(data_path): targets={"binary": targets[i]}, radius=radius, distance_cutoff=distance_cutoff, - ) + ), ) start = time.perf_counter() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c5f42d45c..49f1b4b3b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -8,6 +8,7 @@ import numpy as np import pytest import torch +from numpy.typing import NDArray from torch_geometric.loader import DataLoader from deeprank2.dataset import GraphDataset, GridDataset, save_hdf5_keys @@ -30,7 +31,7 @@ def _compute_features_manually( hdf5_path: str, features_transform: dict, feat: str, -): +) -> (NDArray, float, float): """Return specified feature. This function returns the feature specified read from the hdf5 file, after applying manually features_transform dict. 
@@ -64,7 +65,7 @@ def _compute_features_manually( for entry_name in entry_names ] else: - print(f"Feat {feat} not present in the file.") + warnings.warn(f"Feat {feat} not present in the file.") # apply transformation if transform: @@ -78,7 +79,7 @@ def _compute_features_manually( return arr, mean, dev -def _compute_features_with_get(hdf5_path: str, dataset: GraphDataset): +def _compute_features_with_get(hdf5_path: str, dataset: GraphDataset) -> dict[str, NDArray]: # This function computes features using the Dataset `get` method, # so as they will be seen by the network. It returns a dictionary # whose keys are the features' names and values are the features' values. @@ -131,7 +132,7 @@ def _check_inherited_params( inherited_params: list[str], dataset_train: GraphDataset | GridDataset, dataset_test: GraphDataset | GridDataset, -): +) -> None: dataset_train_vars = vars(dataset_train) dataset_test_vars = vars(dataset_test) @@ -140,10 +141,10 @@ def _check_inherited_params( class TestDataSet(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.hdf5_path = "tests/data/hdf5/1ATN_ppi.hdf5" - def test_collates_entry_names_datasets(self): + def test_collates_entry_names_datasets(self) -> None: for dataset_name, dataset in [ ( "GraphDataset", @@ -167,16 +168,14 @@ def test_collates_entry_names_datasets(self): for batch_data in DataLoader(dataset, batch_size=2, shuffle=True): entry_names += batch_data.entry_names - assert set(entry_names) == set( # noqa: C405 (unnecessary-literal-set) - [ - "residue-ppi-1ATN_1w:A-B", - "residue-ppi-1ATN_2w:A-B", - "residue-ppi-1ATN_3w:A-B", - "residue-ppi-1ATN_4w:A-B", - ], - ), f"entry names of {dataset_name} were not collated correctly" + assert set(entry_names) == { + "residue-ppi-1ATN_1w:A-B", + "residue-ppi-1ATN_2w:A-B", + "residue-ppi-1ATN_3w:A-B", + "residue-ppi-1ATN_4w:A-B", + }, f"entry names of {dataset_name} were not collated correctly" - def test_datasets(self): + def test_datasets(self) -> None: dataset_graph = GraphDataset( hdf5_path=self.hdf5_path, subset=None, @@ -197,7 +196,7 @@ def test_datasets(self): assert len(dataset_grid) == 4 assert dataset_grid[0] is not None - def test_regression_griddataset(self): + def test_regression_griddataset(self) -> None: dataset = GridDataset( hdf5_path=self.hdf5_path, features=[Efeat.VDW, Efeat.ELEC], @@ -218,7 +217,7 @@ def test_regression_griddataset(self): # 1 entry with rmsd value assert dataset[0].y.shape == (1,) - def test_classification_griddataset(self): + def test_classification_griddataset(self) -> None: dataset = GridDataset( hdf5_path=self.hdf5_path, features=[Efeat.VDW, Efeat.ELEC], @@ -239,7 +238,7 @@ def test_classification_griddataset(self): # 1 entry with class value assert dataset[0].y.shape == (1,) - def test_inherit_info_dataset_train_griddataset(self): + def test_inherit_info_dataset_train_griddataset(self) -> None: dataset_train = GridDataset( hdf5_path=self.hdf5_path, features=[Efeat.VDW, Efeat.ELEC], @@ -276,7 +275,7 @@ def test_inherit_info_dataset_train_griddataset(self): dataset_test, ) - def test_inherit_info_pretrained_model_griddataset(self): + def test_inherit_info_pretrained_model_griddataset(self) -> None: # Test the inheritance not giving in any parameters pretrained_model = "tests/data/pretrained/testing_grid_model.pth.tar" dataset_test = GridDataset( @@ -307,7 +306,7 @@ def test_inherit_info_pretrained_model_griddataset(self): for param in dataset_test.inherited_params: assert dataset_test_vars[param] == data[param] - def 
test_no_target_dataset_griddataset(self): + def test_no_target_dataset_griddataset(self) -> None: hdf5_no_target = "tests/data/hdf5/test_no_target.hdf5" hdf5_target = "tests/data/hdf5/1ATN_ppi.hdf5" pretrained_model = "tests/data/pretrained/testing_grid_model.pth.tar" @@ -328,7 +327,7 @@ def test_no_target_dataset_griddataset(self): with pytest.raises(ValueError): dataset = GridDataset(hdf5_path=hdf5_target, target="CAPRI") - def test_filter_griddataset(self): + def test_filter_griddataset(self) -> None: # filtering out all values with pytest.raises(IndexError): GridDataset( @@ -346,7 +345,7 @@ def test_filter_griddataset(self): ) assert len(dataset) == 3 - def test_filter_graphdataset(self): + def test_filter_graphdataset(self) -> None: # filtering out all values with pytest.raises(IndexError): GraphDataset( @@ -368,7 +367,7 @@ def test_filter_graphdataset(self): ) assert len(dataset) == 3 - def test_multi_file_graphdataset(self): + def test_multi_file_graphdataset(self) -> None: dataset = GraphDataset( hdf5_path=["tests/data/hdf5/train.hdf5", "tests/data/hdf5/valid.hdf5"], node_features=node_feats, @@ -379,7 +378,7 @@ def test_multi_file_graphdataset(self): assert dataset.len() > 0 assert dataset.get(0) is not None - def test_save_external_links_graphdataset(self): + def test_save_external_links_graphdataset(self) -> None: n = 2 with h5py.File("tests/data/hdf5/test.hdf5", "r") as hdf5: @@ -399,7 +398,7 @@ def test_save_external_links_graphdataset(self): for new_id in new_ids: assert new_id in original_ids - def test_save_hard_links_graphdataset(self): + def test_save_hard_links_graphdataset(self) -> None: n = 2 with h5py.File("tests/data/hdf5/test.hdf5", "r") as hdf5: @@ -420,7 +419,7 @@ def test_save_hard_links_graphdataset(self): for new_id in new_ids: assert new_id in original_ids - def test_subset_graphdataset(self): + def test_subset_graphdataset(self) -> None: hdf5 = h5py.File("tests/data/hdf5/train.hdf5", "r") # contains 44 datapoints hdf5_keys = list(hdf5.keys()) n = 10 @@ -443,7 +442,7 @@ def test_subset_graphdataset(self): hdf5.close() - def test_target_transform_graphdataset(self): + def test_target_transform_graphdataset(self) -> None: dataset = GraphDataset( hdf5_path="tests/data/hdf5/train.hdf5", target="BA", # continuous values --> regression @@ -454,7 +453,7 @@ def test_target_transform_graphdataset(self): for i in range(len(dataset)): assert 0 <= dataset.get(i).y <= 1 - def test_invalid_target_transform_graphdataset(self): + def test_invalid_target_transform_graphdataset(self) -> None: dataset = GraphDataset( hdf5_path="tests/data/hdf5/train.hdf5", target=targets.BINARY, # --> classification @@ -464,7 +463,7 @@ def test_invalid_target_transform_graphdataset(self): with pytest.raises(ValueError): dataset.get(0) - def test_size_graphdataset(self): + def test_size_graphdataset(self) -> None: hdf5_paths = [ "tests/data/hdf5/train.hdf5", "tests/data/hdf5/valid.hdf5", @@ -482,7 +481,7 @@ def test_size_graphdataset(self): n += len(hdf5_r.keys()) assert len(dataset) == n, f"total data points got was {len(dataset)}" - def test_hdf5_to_pandas_graphdataset(self): + def test_hdf5_to_pandas_graphdataset(self) -> None: # noqa: C901 hdf5_path = "tests/data/hdf5/train.hdf5" dataset = GraphDataset( hdf5_path=hdf5_path, @@ -564,7 +563,7 @@ def test_hdf5_to_pandas_graphdataset(self): assert dataset.df.shape[0] == len(keys[2:]) - def test_save_hist_graphdataset(self): + def test_save_hist_graphdataset(self) -> None: output_directory = mkdtemp() fname = os.path.join(output_directory, 
"test.png") hdf5_path = "tests/data/hdf5/test.hdf5" @@ -580,7 +579,7 @@ def test_save_hist_graphdataset(self): rmtree(output_directory) - def test_logic_train_graphdataset(self): + def test_logic_train_graphdataset(self) -> None: hdf5_path = "tests/data/hdf5/train.hdf5" # without specifying features_transform in training set @@ -613,7 +612,7 @@ def test_logic_train_graphdataset(self): target="binary", ) - def test_only_transform_graphdataset(self): + def test_only_transform_graphdataset(self) -> None: # define a features_transform dict for only transformations, # including node (bsa) and edge features (electrostatic), # a multi-channel feature (hse) and a case with transform equals to None (sasa) @@ -722,7 +721,7 @@ def test_only_transform_graphdataset(self): assert sorted(checked_features) == sorted(features_transform.keys()) assert len(checked_features) == len(features_transform.keys()) - def test_only_transform_all_graphdataset(self): + def test_only_transform_all_graphdataset(self) -> None: # define a features_transform dict for only transformations for `all` features hdf5_path = "tests/data/hdf5/train.hdf5" @@ -799,7 +798,7 @@ def test_only_transform_all_graphdataset(self): assert sorted(checked_features) == sorted(features) assert len(checked_features) == len(features) - def test_only_standardize_graphdataset(self): + def test_only_standardize_graphdataset(self) -> None: # define a features_transform dict for only standardization, # including node (bsa) and edge features (electrostatic), # a multi-channel feature (hse) and a case with standardize False (sasa) @@ -909,7 +908,7 @@ def test_only_standardize_graphdataset(self): assert sorted(checked_features) == sorted(features_transform.keys()) assert len(checked_features) == len(features_transform.keys()) - def test_only_standardize_all_graphdataset(self): + def test_only_standardize_all_graphdataset(self) -> None: # define a features_transform dict for only standardization for `all` features hdf5_path = "tests/data/hdf5/train.hdf5" features_transform = {"all": {"standardize": True}} @@ -987,7 +986,7 @@ def test_only_standardize_all_graphdataset(self): assert sorted(checked_features) == sorted(features) assert len(checked_features) == len(features) - def test_transform_standardize_graphdataset(self): + def test_transform_standardize_graphdataset(self) -> None: # define a features_transform dict for both transformations and standardization, # including node (bsa) and edge features (electrostatic), # a multi-channel feature (hse) @@ -1094,7 +1093,7 @@ def test_transform_standardize_graphdataset(self): assert sorted(checked_features) == sorted(features_transform.keys()) assert len(checked_features) == len(features_transform.keys()) - def test_features_transform_logic_graphdataset(self): + def test_features_transform_logic_graphdataset(self) -> None: hdf5_path = "tests/data/hdf5/train.hdf5" features_transform = {"all": {"transform": lambda t: np.cbrt(t), "standardize": True}} other_feature_transform = {"all": {"transform": None, "standardize": False}} @@ -1131,7 +1130,7 @@ def test_features_transform_logic_graphdataset(self): assert dataset_train.means == dataset_test.means assert dataset_train.devs == dataset_test.devs - def test_invalid_value_features_transform(self): + def test_invalid_value_features_transform(self) -> None: hdf5_path = "tests/data/hdf5/train.hdf5" features_transform = {"all": {"transform": lambda t: np.log(t + 10), "standardize": True}} @@ -1145,7 +1144,7 @@ def test_invalid_value_features_transform(self): with 
             _compute_features_with_get(hdf5_path, transf_dataset)

-    def test_inherit_info_dataset_train_graphdataset(self):
+    def test_inherit_info_dataset_train_graphdataset(self) -> None:
         hdf5_path = "tests/data/hdf5/train.hdf5"
         feature_transform = {"all": {"transform": None, "standardize": True}}

@@ -1189,7 +1188,7 @@ def test_inherit_info_dataset_train_graphdataset(self):
             dataset_test,
         )

-    def test_inherit_info_pretrained_model_graphdataset(self):
+    def test_inherit_info_pretrained_model_graphdataset(self) -> None:
         hdf5_path = "tests/data/hdf5/test.hdf5"
         pretrained_model = "tests/data/pretrained/testing_graph_model.pth.tar"
         dataset_test = GraphDataset(
@@ -1202,7 +1201,7 @@ def test_inherit_info_pretrained_model_graphdataset(self):
         for key in data["features_transform"].values():
             if key["transform"] is None:
                 continue
-            key["transform"] = eval(key["transform"])  # noqa: S307, PGH001 (suspicious-eval-usage)
+            key["transform"] = eval(key["transform"])  # noqa: S307, PGH001

         dataset_test_vars = vars(dataset_test)
         for param in dataset_test.inherited_params:
@@ -1236,7 +1235,7 @@ def test_inherit_info_pretrained_model_graphdataset(self):
             else:
                 assert dataset_test_vars[param] == data[param]

-    def test_no_target_dataset_graphdataset(self):
+    def test_no_target_dataset_graphdataset(self) -> None:
         hdf5_no_target = "tests/data/hdf5/test_no_target.hdf5"
         hdf5_target = "tests/data/hdf5/test.hdf5"
         pretrained_model = "tests/data/pretrained/testing_graph_model.pth.tar"
@@ -1260,7 +1259,7 @@ def test_no_target_dataset_graphdataset(self):
                 target="CAPRI",
             )

-    def test_incompatible_dataset_train_type(self):
+    def test_incompatible_dataset_train_type(self) -> None:
         dataset_train = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             edge_features=[Efeat.DISTANCE, Efeat.COVALENT],
@@ -1274,7 +1273,7 @@ def test_incompatible_dataset_train_type(self):
                 train_source=dataset_train,
             )

-    def test_invalid_pretrained_model_path(self):
+    def test_invalid_pretrained_model_path(self) -> None:
         hdf5_graph = "tests/data/hdf5/test.hdf5"
         with pytest.raises(ValueError):
             GraphDataset(
@@ -1289,7 +1288,7 @@ def test_invalid_pretrained_model_path(self):
                 train_source=hdf5_grid,
             )

-    def test_invalid_pretrained_model_data_type(self):
+    def test_invalid_pretrained_model_data_type(self) -> None:
         hdf5_graph = "tests/data/hdf5/test.hdf5"
         pretrained_grid_model = "tests/data/pretrained/testing_grid_model.pth.tar"
         with pytest.raises(TypeError):
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 1fee2ae35..c6ac34ac5 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -31,7 +31,7 @@
 count_queries = 3


-def test_cnn():
+def test_cnn() -> None:
     """
     Tests processing several PDB files into their features representation HDF5 file.

@@ -124,7 +124,7 @@ def test_cnn():

     rmtree(output_directory)


-def test_gnn():
+def test_gnn() -> None:
     """Tests processing several PDB files into their features representation HDF5 file.

     Then uses HDF5 generated files to train and test a GINet network.

@@ -219,7 +219,7 @@ def test_gnn():


 @pytest.fixture(scope="session")
-def hdf5_files_for_nan(tmpdir_factory):
+def hdf5_files_for_nan(tmpdir_factory: str) -> QueryCollection:
     # For testing cases in which the loss function is nan for the validation and/or for
     # the training sets. It doesn't matter if the dataset is a GraphDataset or a GridDataset,
     # since it is a functionality of the trainer module, which does not depend on the dataset type.
@@ -250,13 +250,17 @@ def hdf5_files_for_nan(tmpdir_factory):
     return queries.process(prefix=prefix)


-@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)])  # noqa: PT006 (pytest-parametrize-names-wrong-type)
-def test_nan_loss_cases(validate, best_model, hdf5_files_for_nan):
+@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)])  # noqa: PT006
+def test_nan_loss_cases(
+    validate: bool,
+    best_model: bool,
+    hdf5_files_for_nan,  # noqa: ANN001
+) -> None:
     mols = []
     for fname in hdf5_files_for_nan:
         with h5py.File(fname, "r") as hdf5:
             for mol in hdf5:
-                mols.append(mol)  # noqa: PERF402 (manual-list-copy)
+                mols.append(mol)  # noqa: PERF402

     dataset_train = GraphDataset(
         hdf5_path=hdf5_files_for_nan,
diff --git a/tests/test_query.py b/tests/test_query.py
index 73e9f2834..a194f523e 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -25,7 +25,7 @@ def _check_graph_makes_sense(
     g: Graph,
     node_feature_names: list[str],
     edge_feature_names: list[str],
-):
+) -> None:
     assert len(g.nodes) > 0, "no nodes"
     assert Nfeat.POSITION in g.nodes[0].features

@@ -34,7 +34,8 @@ def _check_graph_makes_sense(

     for edge in g.edges:
         if edge.id.item1 == edge.id.item2:
-            raise ValueError(f"an edge pairs {edge.id.item1} with itself")
+            msg = f"an edge pairs {edge.id.item1} with itself"
+            raise ValueError(msg)

     assert not g.has_nan()

@@ -77,7 +78,7 @@ def _check_graph_makes_sense(

     os.remove(tmp_path)


-def test_interface_graph_residue():
+def test_interface_graph_residue() -> None:
     query = ProteinProteinInterfaceQuery(
         pdb_path="tests/data/pdb/3C8P/3C8P.pdb",
         resolution="residue",
@@ -101,7 +102,7 @@ def test_interface_graph_residue():
     )


-def test_interface_graph_atomic():
+def test_interface_graph_atomic() -> None:
     query = ProteinProteinInterfaceQuery(
         pdb_path="tests/data/pdb/3C8P/3C8P.pdb",
         resolution="atom",
@@ -127,7 +128,7 @@ def test_interface_graph_atomic():
     )


-def test_variant_graph_101M():
+def test_variant_graph_101M() -> None:
     query = SingleResidueVariantQuery(
         pdb_path="tests/data/pdb/101M/101M.pdb",
         resolution="atom",
@@ -160,7 +161,7 @@ def test_variant_graph_101M():
     )


-def test_variant_graph_1A0Z():
+def test_variant_graph_1A0Z() -> None:
     query = SingleResidueVariantQuery(
         pdb_path="tests/data/pdb/1A0Z/1A0Z.pdb",
         resolution="atom",
@@ -198,7 +199,7 @@ def test_variant_graph_1A0Z():
     )


-def test_variant_graph_9API():
+def test_variant_graph_9API() -> None:
     query = SingleResidueVariantQuery(
         pdb_path="tests/data/pdb/9api/9api.pdb",
         resolution="atom",
@@ -234,7 +235,7 @@ def test_variant_graph_9API():
     )


-def test_variant_residue_graph_101M():
+def test_variant_residue_graph_101M() -> None:
     query = SingleResidueVariantQuery(
         pdb_path="tests/data/pdb/101M/101M.pdb",
         resolution="residue",
@@ -262,7 +263,7 @@ def test_variant_residue_graph_101M():
     )


-def test_res_ppi():
+def test_res_ppi() -> None:
     query = ProteinProteinInterfaceQuery(
         pdb_path="tests/data/pdb/3MRC/3MRC.pdb",
         resolution="residue",
@@ -272,7 +273,7 @@ def test_res_ppi():
     _check_graph_makes_sense(g, [Nfeat.SASA], [Efeat.ELEC])


-def test_augmentation():
+def test_augmentation() -> None:
     qc = QueryCollection()

     qc.add(
@@ -285,7 +286,7 @@ def test_augmentation():
                 "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm",
             },
             targets={targets.BINARY: 0},
-        )
+        ),
     )

     qc.add(
@@ -298,7 +299,7 @@ def test_augmentation():
                 "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm",
             },
             targets={targets.BINARY: 0},
-        )
+        ),
     )

     qc.add(
@@ -312,7 +313,7 @@ def test_augmentation():
             variant_amino_acid=aa.alanine,
             pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"},
             targets={targets.BINARY: 0},
-        )
+        ),
     )

     qc.add(
@@ -327,7 +328,7 @@ def test_augmentation():
             pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"},
             targets={targets.BINARY: 0},
             influence_radius=3.0,
-        )
+        ),
     )

     augmentation_count = 3
@@ -358,7 +359,7 @@ def test_augmentation():
     shutil.rmtree(tmp_dir)


-def test_incorrect_pssm_order():
+def test_incorrect_pssm_order() -> None:
     q = ProteinProteinInterfaceQuery(
         pdb_path="tests/data/pdb/3C8P/3C8P.pdb",
         resolution="residue",
@@ -382,7 +383,7 @@ def test_incorrect_pssm_order():
         _ = q.build(conservation)


-def test_incomplete_pssm():
+def test_incomplete_pssm() -> None:
     q = ProteinProteinInterfaceQuery(
         pdb_path="tests/data/pdb/3C8P/3C8P.pdb",
         resolution="residue",
@@ -405,7 +406,7 @@ def test_incomplete_pssm():
         _ = q.build(conservation)


-def test_no_pssm_provided():
+def test_no_pssm_provided() -> None:
     # pssm_paths is empty dictionary
     q_empty_dict = ProteinProteinInterfaceQuery(
         pdb_path="tests/data/pdb/3C8P/3C8P.pdb",
@@ -431,7 +432,7 @@ def test_no_pssm_provided():
         _ = q_not_provided.build([components])


-def test_incorrect_pssm_provided():
+def test_incorrect_pssm_provided() -> None:
     # non-existing file
     q_non_existing = ProteinProteinInterfaceQuery(
         pdb_path="tests/data/pdb/3C8P/3C8P.pdb",
@@ -463,7 +464,7 @@ def test_incorrect_pssm_provided():
         _ = q_missing.build([components])


-def test_variant_query_multiple_chains():
+def test_variant_query_multiple_chains() -> None:
     q = SingleResidueVariantQuery(
         pdb_path="tests/data/pdb/2g98/pdb2g98.pdb",
         resolution="atom",
diff --git a/tests/test_querycollection.py b/tests/test_querycollection.py
index 94077731f..904796634 100644
--- a/tests/test_querycollection.py
+++ b/tests/test_querycollection.py
@@ -18,10 +18,10 @@

 def _querycollection_tester(
     query_type: str,
     n_queries: int = 3,
-    feature_modules: ModuleType | list[ModuleType] = [components, contact],  # noqa: B006 (unsafe default value)
+    feature_modules: ModuleType | list[ModuleType] | None = None,
     cpu_count: int = 1,
     combine_output: bool = True,
-):
+) -> tuple[QueryCollection, str, list[str]]:
     """
     Generic function to test QueryCollection class.

@@ -35,6 +35,7 @@ def _querycollection_tester(
         combine_output (bool): boolean for combining the hdf5 files generated by the processes.
             By default, the hdf5 files generated are combined into one, and then deleted.
     """
+    feature_modules = feature_modules or [components, contact]
     if query_type == "ppi":
         queries = [
             ProteinProteinInterfaceQuery(
@@ -45,7 +46,7 @@ def _querycollection_tester(
                     "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm",
                     "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm",
                 },
-            )
+            ),
         ] * n_queries
     elif query_type == "srv":
         queries = [
@@ -58,10 +59,11 @@ def _querycollection_tester(
                 wildtype_amino_acid=alanine,
                 variant_amino_acid=phenylalanine,
                 pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"},
-            )
+            ),
         ] * n_queries
     else:
-        raise ValueError("Please insert a valid type (either ppi or srv).")
+        msg = "Please insert a valid type (either ppi or srv)."
+        raise ValueError(msg)

     output_directory = mkdtemp()
     prefix = join(output_directory, "test-process-queries")
@@ -97,7 +99,7 @@ def _assert_correct_modules(
     output_paths: str,
     features: str | list[str],
     absent: str,
-):
+) -> None:
     """Helper function to assert inclusion of correct features.
     Args:
@@ -119,13 +121,14 @@ def _assert_correct_modules(
             except KeyError:
                 missing.append(feat)
         if missing:
-            raise KeyError(f"The following feature(s) were not created: {missing}.")
+            msg = f"The following feature(s) were not created: {missing}."
+            raise KeyError(msg)

         with pytest.raises(KeyError):
             _ = f5[next(iter(f5.keys()))][f"{Nfeat.NODE}/{absent}"]


-def test_querycollection_process():
+def test_querycollection_process() -> None:
     """Tests processing method of QueryCollection class."""
     for query_type in ["ppi", "srv"]:
         n_queries = 3
@@ -141,7 +144,7 @@ def test_querycollection_process():

         rmtree(output_directory)


-def test_querycollection_process_single_feature_module():
+def test_querycollection_process_single_feature_module() -> None:
     """Test processing for generating from a single feature module.

     Tested for following input types: ModuleType, list[ModuleType] str, list[str]
@@ -153,7 +156,7 @@ def test_querycollection_process_single_feature_module():

         rmtree(output_directory)


-def test_querycollection_process_all_features_modules():
+def test_querycollection_process_all_features_modules() -> None:
     """Tests processing for generating all features."""
     one_feature_from_each_module = [
         Nfeat.RESTYPE,
@@ -179,7 +182,7 @@ def test_querycollection_process_all_features_modules():

         rmtree(output_directory)


-def test_querycollection_process_default_features_modules():
+def test_querycollection_process_default_features_modules() -> None:
     """Tests processing for generating all features."""
     for query_type in ["ppi", "srv"]:
         _, output_directory, output_paths = _querycollection_tester(query_type)
@@ -192,7 +195,7 @@ def test_querycollection_process_default_features_modules():

         rmtree(output_directory)


-def test_querycollection_process_combine_output_true():
+def test_querycollection_process_combine_output_true() -> None:
     """Tests processing for combining hdf5 files into one."""
     for query_type in ["ppi", "srv"]:
         modules = [surfacearea, components]
@@ -215,7 +218,7 @@ def test_querycollection_process_combine_output_true():

         rmtree(output_directory_f)


-def test_querycollection_process_combine_output_false():
+def test_querycollection_process_combine_output_false() -> None:
     """Tests processing for keeping all generated hdf5 files ."""
     for query_type in ["ppi", "srv"]:
         cpu_count = 2
@@ -232,7 +235,7 @@ def test_querycollection_process_combine_output_false():

         rmtree(output_directory)


-def test_querycollection_duplicates_add():
+def test_querycollection_duplicates_add() -> None:
     """Tests add method of QueryCollection class."""
     ref_path = "tests/data/ref/1ATN/1ATN.pdb"
     pssm_path1 = "tests/data/pssm/1ATN/1ATN.A.pdb.pssm"
@@ -261,7 +264,7 @@ def test_querycollection_duplicates_add():
                 chain_ids=[chain_id1, chain_id2],
                 targets=targets,
                 pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},
-            )
+            ),
         )

     # check id naming for all pdb files
@@ -276,6 +279,6 @@ def test_querycollection_duplicates_add():
         "1ATN_2w_2",
         "1ATN_3w",
     ]
-    assert queries._ids_count["residue-ppi:A-B:1ATN_1w"] == 3  # noqa: SLF001 (private member accessed)
-    assert queries._ids_count["residue-ppi:A-B:1ATN_2w"] == 2  # noqa: SLF001 (private member accessed)
-    assert queries._ids_count["residue-ppi:A-B:1ATN_3w"] == 1  # noqa: SLF001 (private member accessed)
+    assert queries._ids_count["residue-ppi:A-B:1ATN_1w"] == 3  # noqa: SLF001
+    assert queries._ids_count["residue-ppi:A-B:1ATN_2w"] == 2  # noqa: SLF001
+    assert queries._ids_count["residue-ppi:A-B:1ATN_3w"] == 1  # noqa: SLF001
diff --git a/tests/test_set_lossfunction.py b/tests/test_set_lossfunction.py
index 191d54820..252fd3269 100644
--- a/tests/test_set_lossfunction.py
+++ b/tests/test_set_lossfunction.py
@@ -16,11 +16,11 @@


 def base_test(
-    model_path,
+    model_path: str,
     trainer: Trainer,
-    lossfunction=None,
-    override=False,
-):
+    lossfunction: nn.modules.loss._Loss | None = None,
+    override: bool = False,
+) -> Trainer:
     if lossfunction:
         trainer.set_lossfunction(lossfunction=lossfunction, override_invalid=override)

@@ -37,16 +37,16 @@ def base_test(


 class TestLosses(unittest.TestCase):
     @classmethod
-    def setUpClass(class_):
+    def setUpClass(class_) -> None:
         class_.work_directory = tempfile.mkdtemp()
         class_.save_path = class_.work_directory + "test.tar"

     @classmethod
-    def tearDownClass(class_):
+    def tearDownClass(class_) -> None:
         shutil.rmtree(class_.work_directory)

     # Classification tasks
-    def test_classif_default(self):
+    def test_classif_default(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target=targets.BINARY,
@@ -60,7 +60,7 @@ def test_classif_default(self):
         assert isinstance(trainer.lossfunction, nn.CrossEntropyLoss)
         assert isinstance(trainer_pretrained.lossfunction, nn.CrossEntropyLoss)

-    def test_classif_all(self):
+    def test_classif_all(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target=targets.BINARY,
@@ -77,7 +77,7 @@ def test_classif_all(self):
         assert isinstance(trainer.lossfunction, lossfunction)
         assert isinstance(trainer_pretrained.lossfunction, lossfunction)

-    def test_classif_weighted(self):
+    def test_classif_weighted(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target=targets.BINARY,
@@ -94,7 +94,7 @@ def test_classif_weighted(self):
         assert isinstance(trainer_pretrained.lossfunction, lossfunction)
         assert trainer_pretrained.class_weights

-    def test_classif_invalid_weighted(self):
+    def test_classif_invalid_weighted(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target=targets.BINARY,
@@ -110,7 +110,7 @@ def test_classif_invalid_weighted(self):
         with pytest.raises(ValueError):
             base_test(self.save_path, trainer, lossfunction)

-    def test_classif_invalid_lossfunction(self):
+    def test_classif_invalid_lossfunction(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target=targets.BINARY,
@@ -124,7 +124,7 @@ def test_classif_invalid_lossfunction(self):
         with pytest.raises(ValueError):
             base_test(self.save_path, trainer, lossfunction)

-    def test_classif_invalid_lossfunction_override(self):
+    def test_classif_invalid_lossfunction_override(self) -> None:
         dataset = GraphDataset(hdf5_path, target=targets.BINARY)
         trainer = Trainer(
             neuralnet=NaiveNetwork,
@@ -141,7 +141,7 @@ def test_classif_invalid_lossfunction_override(self):
         )

     # Regression tasks
-    def test_regress_default(self):
+    def test_regress_default(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target="BA",
@@ -156,7 +156,7 @@ def test_regress_default(self):
         assert isinstance(trainer.lossfunction, nn.MSELoss)
         assert isinstance(trainer_pretrained.lossfunction, nn.MSELoss)

-    def test_regress_all(self):
+    def test_regress_all(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target="BA",
@@ -173,7 +173,7 @@ def test_regress_all(self):
         assert isinstance(trainer.lossfunction, lossfunction)
         assert isinstance(trainer_pretrained.lossfunction, lossfunction)

-    def test_regress_invalid_lossfunction(self):
+    def test_regress_invalid_lossfunction(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target="BA",
@@ -188,7 +188,7 @@ def test_regress_invalid_lossfunction(self):
         with pytest.raises(ValueError):
             base_test(self.save_path, trainer, lossfunction)

-    def test_regress_invalid_lossfunction_override(self):
+    def test_regress_invalid_lossfunction_override(self) -> None:
         dataset = GraphDataset(
             hdf5_path,
             target="BA",
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 96abc7ae2..591cb7cf3 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -39,20 +39,20 @@


 def _model_base_test(
-    save_path,
-    model_class,
-    train_hdf5_path,
-    val_hdf5_path,
-    test_hdf5_path,
-    node_features,
-    edge_features,
-    task,
-    target,
-    target_transform,
-    output_exporters,
-    clustering_method,
-    use_cuda=False,
-):
+    save_path: str,
+    model_class: torch.nn.Module,
+    train_hdf5_path: str,
+    val_hdf5_path: str,
+    test_hdf5_path: str,
+    node_features: list[str],
+    edge_features: list[str],
+    task: str,
+    target: str,
+    target_transform: bool,
+    output_exporters: list[HDF5OutputExporter],
+    clustering_method: str,
+    use_cuda: bool = False,
+) -> None:
     dataset_train = GraphDataset(
         hdf5_path=train_hdf5_path,
         node_features=node_features,
@@ -128,15 +128,15 @@ def _model_base_test(


 class TestTrainer(unittest.TestCase):
     @classmethod
-    def setUpClass(class_):
+    def setUpClass(class_) -> None:
         class_.work_directory = tempfile.mkdtemp()
         class_.save_path = class_.work_directory + "test.tar"

     @classmethod
-    def tearDownClass(class_):
+    def tearDownClass(class_) -> None:
         shutil.rmtree(class_.work_directory)

-    def test_grid_regression(self):
+    def test_grid_regression(self) -> None:
         dataset = GridDataset(
             hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5",
             subset=None,
@@ -147,7 +147,7 @@ def test_grid_regression(self):
         trainer = Trainer(CnnRegression, dataset)
         trainer.train(nepoch=1, batch_size=2, best_model=False, filename=None)

-    def test_grid_classification(self):
+    def test_grid_classification(self) -> None:
         dataset = GridDataset(
             hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5",
             subset=None,
@@ -163,7 +163,7 @@ def test_grid_classification(self):
             filename=None,
         )

-    def test_ginet_sigmoid(self):
+    def test_ginet_sigmoid(self) -> None:
         files = glob.glob(self.work_directory + "/*")
         for f in files:
             os.remove(f)
@@ -185,7 +185,7 @@ def test_ginet_sigmoid(self):
         )
         assert len(os.listdir(self.work_directory)) > 0

-    def test_ginet(self):
+    def test_ginet(self) -> None:
         files = glob.glob(self.work_directory + "/*")
         for f in files:
             os.remove(f)
@@ -207,7 +207,7 @@ def test_ginet(self):
         )
         assert len(os.listdir(self.work_directory)) > 0

-    def test_ginet_class(self):
+    def test_ginet_class(self) -> None:
         files = glob.glob(self.work_directory + "/*")
         for f in files:
             os.remove(f)
@@ -230,7 +230,7 @@ def test_ginet_class(self):

         assert len(os.listdir(self.work_directory)) > 0

-    def test_fout(self):
+    def test_fout(self) -> None:
         files = glob.glob(self.work_directory + "/*")
         for f in files:
             os.remove(f)
@@ -252,7 +252,7 @@ def test_fout(self):
         )
         assert len(os.listdir(self.work_directory)) > 0

-    def test_sgat(self):
+    def test_sgat(self) -> None:
         files = glob.glob(self.work_directory + "/*")
         for f in files:
             os.remove(f)
@@ -274,7 +274,7 @@ def test_sgat(self):
         )
         assert len(os.listdir(self.work_directory)) > 0

-    def test_naive(self):
+    def test_naive(self) -> None:
         files = glob.glob(self.work_directory + "/*")
         for f in files:
             os.remove(f)
@@ -296,7 +296,7 @@ def test_naive(self):
         )
         assert len(os.listdir(self.work_directory)) > 0

-    def test_incompatible_regression(self):
+    def test_incompatible_regression(self) -> None:
         with pytest.raises(ValueError):
             _model_base_test(
                 self.save_path,
@@ -313,7 +313,7 @@ def test_incompatible_regression(self):
                 "mcl",
             )

-    def test_incompatible_classification(self):
+    def test_incompatible_classification(self) -> None:
         with pytest.raises(ValueError):
             _model_base_test(
                 self.save_path,
@@ -336,7 +336,7 @@ def test_incompatible_classification(self):
                 "mcl",
             )

-    def test_incompatible_no_pretrained_no_train(self):
+    def test_incompatible_no_pretrained_no_train(self) -> None:
         dataset = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             target=targets.BINARY,
@@ -348,13 +348,13 @@ def test_incompatible_no_pretrained_no_train(self):
                 dataset_test=dataset,
             )

-    def test_incompatible_no_pretrained_no_Net(self):
+    def test_incompatible_no_pretrained_no_Net(self) -> None:
         with pytest.raises(ValueError):
             _ = GraphDataset(
                 hdf5_path="tests/data/hdf5/test.hdf5",
             )

-    def test_incompatible_no_pretrained_no_target(self):
+    def test_incompatible_no_pretrained_no_target(self) -> None:
         dataset = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             target=targets.BINARY,
@@ -364,7 +364,7 @@ def test_incompatible_no_pretrained_no_target(self):
                 dataset_train=dataset,
             )

-    def test_incompatible_pretrained_no_test(self):
+    def test_incompatible_pretrained_no_test(self) -> None:
         dataset = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             clustering_method="mcl",
@@ -384,7 +384,7 @@ def test_incompatible_pretrained_no_test(self):
                 pretrained_model=self.save_path,
             )

-    def test_incompatible_pretrained_no_Net(self):
+    def test_incompatible_pretrained_no_Net(self) -> None:
         dataset = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             clustering_method="mcl",
@@ -400,7 +400,7 @@ def test_incompatible_pretrained_no_Net(self):
         with pytest.raises(ValueError):
             Trainer(dataset_test=dataset, pretrained_model=self.save_path)

-    def test_no_training_no_pretrained(self):
+    def test_no_training_no_pretrained(self) -> None:
         dataset_train = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             clustering_method="mcl",
@@ -417,7 +417,7 @@ def test_no_training_no_pretrained(self):
         with pytest.raises(ValueError):
             trainer.test()

-    def test_no_valid_provided(self):
+    def test_no_valid_provided(self) -> None:
         dataset = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             clustering_method="mcl",
@@ -431,7 +431,7 @@ def test_no_valid_provided(self):
         assert len(trainer.train_loader) == int(0.75 * len(dataset))
         assert len(trainer.valid_loader) == int(0.25 * len(dataset))

-    def test_no_test_provided(self):
+    def test_no_test_provided(self) -> None:
         dataset_train = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             clustering_method="mcl",
@@ -447,7 +447,7 @@ def test_no_test_provided(self):
         with pytest.raises(ValueError):
             trainer.test()

-    def test_no_valid_full_train(self):
+    def test_no_valid_full_train(self) -> None:
         dataset = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             clustering_method="mcl",
@@ -462,7 +462,7 @@ def test_no_valid_full_train(self):
         assert len(trainer.train_loader) == len(dataset)
         assert trainer.valid_loader is None

-    def test_optim(self):
+    def test_optim(self) -> None:
         dataset = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             target=targets.BINARY,
@@ -493,7 +493,7 @@ def test_optim(self):
         assert trainer_pretrained.lr == lr
         assert trainer_pretrained.weight_decay == weight_decay

-    def test_default_optim(self):
+    def test_default_optim(self) -> None:
         dataset = GraphDataset(
             hdf5_path="tests/data/hdf5/test.hdf5",
             target=targets.BINARY,
@@ -507,7 +507,7 @@ def test_default_optim(self):
         assert trainer.lr == 0.001
         assert trainer.weight_decay == 1e-05

-    def test_cuda(self):  # test_ginet, but with cuda
+    def test_cuda(self) -> None:  # test_ginet, but with cuda
         if torch.cuda.is_available():
             files = glob.glob(self.work_directory + "/*")
             for f in files:
@@ -535,7 +535,7 @@ def test_cuda(self):  # test_ginet, but with cuda
             warnings.warn("CUDA is not available; test_cuda was skipped")
             _log.info("CUDA is not available; test_cuda was skipped")

-    def test_dataset_equivalence_no_pretrained(self):
+    def test_dataset_equivalence_no_pretrained(self) -> None:
         # TestCase: dataset_train set (no pretrained model assigned).
         # Raise error when train dataset is neither a GraphDataset or GridDataset.

@@ -575,7 +575,7 @@ def test_dataset_equivalence_no_pretrained(self):
                 dataset_test=dataset_test,
             )

-    def test_dataset_equivalence_pretrained(self):
+    def test_dataset_equivalence_pretrained(self) -> None:
         # TestCase: No dataset_train set (pretrained model assigned).
         # Raise error when no dataset_test is set.

@@ -598,7 +598,7 @@ def test_dataset_equivalence_pretrained(self):
         with pytest.raises(ValueError):
             Trainer(neuralnet=GINet, pretrained_model=self.save_path)

-    def test_trainsize(self):
+    def test_trainsize(self) -> None:
         hdf5 = "tests/data/hdf5/train.hdf5"
         hdf5_file = h5py.File(hdf5, "r")  # contains 44 datapoints
         n_val = int(0.25 * len(hdf5_file))
@@ -615,7 +615,7 @@ def test_trainsize(self):

         hdf5_file.close()

-    def test_invalid_trainsize(self):
+    def test_invalid_trainsize(self) -> None:
         hdf5 = "tests/data/hdf5/train.hdf5"
         hdf5_file = h5py.File(hdf5, "r")  # contains 44 datapoints
         n = len(hdf5_file)
@@ -629,7 +629,7 @@ def test_invalid_trainsize(self):
         ]

         for t in test_cases:
-            print(t)
+            print(t)  # noqa: T201, print in case it fails we can see on which one it failed
             with pytest.raises(ValueError):
                 _divide_dataset(
                     dataset=GraphDataset(hdf5_path=hdf5),
@@ -638,7 +638,7 @@ def test_invalid_trainsize(self):

         hdf5_file.close()

-    def test_invalid_cuda_ngpus(self):
+    def test_invalid_cuda_ngpus(self) -> None:
         dataset_train = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", target=targets.BINARY)
         dataset_val = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train)

@@ -650,7 +650,7 @@ def test_invalid_cuda_ngpus(self):
                 ngpu=2,
             )

-    def test_invalid_no_cuda_available(self):
+    def test_invalid_no_cuda_available(self) -> None:
         if not torch.cuda.is_available():
             dataset_train = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", target=targets.BINARY)
             dataset_val = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train)
@@ -667,7 +667,7 @@ def test_invalid_no_cuda_available(self):
             warnings.warn("CUDA is available; test_invalid_no_cuda_available was skipped")
             _log.info("CUDA is available; test_invalid_no_cuda_available was skipped")

-    def test_train_method_no_train(self):
+    def test_train_method_no_train(self) -> None:
         # Graphs data
         test_data_graph = "tests/data/hdf5/test.hdf5"
         pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar"
@@ -696,7 +696,7 @@ def test_train_method_no_train(self):
         with pytest.raises(ValueError):
             trainer.train()

-    def test_test_method_pretrained_model_on_dataset_with_target(self):
+    def test_test_method_pretrained_model_on_dataset_with_target(self) -> None:
         # Graphs data
         test_data_graph = "tests/data/hdf5/test.hdf5"
         pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar"
@@ -733,7 +733,7 @@ def test_test_method_pretrained_model_on_dataset_with_target(self):
         output = pd.read_hdf("output_exporter.hdf5", key="testing")
         assert len(output) == len(dataset_test)

-    def test_test_method_pretrained_model_on_dataset_without_target(self):
+    def test_test_method_pretrained_model_on_dataset_without_target(self) -> None:
         # Graphs data
         test_data_graph = "tests/data/hdf5/test_no_target.hdf5"
         pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar"
diff --git a/tests/tools/test_target.py b/tests/tools/test_target.py
index 577b3aad7..a6a09a078 100644
--- a/tests/tools/test_target.py
+++ b/tests/tools/test_target.py
@@ -10,13 +10,13 @@


 class TestTools(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.pdb_path = "./tests/data/pdb/1ATN/"
         self.pssm_path = "./tests/data/pssm/1ATN/1ATN.A.pdb.pssm"
         self.ref = "./tests/data/ref/1ATN/"
         self.h5_graphs = "tests/data/hdf5/1ATN_ppi.hdf5"

-    def test_add_target(self):
+    def test_add_target(self) -> None:
         f, target_path = tempfile.mkstemp(prefix="target", suffix=".lst")
         os.close(f)
         f, graph_path = tempfile.mkstemp(prefix="1ATN_ppi", suffix=".hdf5")
@@ -34,7 +34,7 @@ def test_add_target(self):
         os.remove(target_path)
         os.remove(graph_path)

-    def test_compute_ppi_scores(self):
+    def test_compute_ppi_scores(self) -> None:
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")

@@ -64,7 +64,7 @@ def test_compute_ppi_scores(self):
         assert scores["binary"] == binary
         assert scores["capri_class"] == capri

-    def test_compute_ppi_scores_same_struct(self):
+    def test_compute_ppi_scores_same_struct(self) -> None:
         scores = compute_ppi_scores(
             os.path.join(self.pdb_path, "1ATN_1w.pdb"),
             os.path.join(self.pdb_path, "1ATN_1w.pdb"),
diff --git a/tests/utils/test_buildgraph.py b/tests/utils/test_buildgraph.py
index 11c8b989b..35559c6c2 100644
--- a/tests/utils/test_buildgraph.py
+++ b/tests/utils/test_buildgraph.py
@@ -5,14 +5,14 @@
 from deeprank2.utils.buildgraph import get_residue_contact_pairs, get_structure, get_surrounding_residues


-def test_get_structure_complete():
+def test_get_structure_complete() -> None:
     pdb_path = "tests/data/pdb/101M/101M.pdb"

     pdb = pdb2sql(pdb_path)
     try:
         structure = get_structure(pdb, "101M")
     finally:
-        pdb._close()  # noqa: SLF001 (private member accessed)
+        pdb._close()  # noqa: SLF001

     assert structure is not None

@@ -34,37 +34,37 @@ def test_get_structure_complete():
             assert atom.residue == residue


-def test_get_structure_from_nmr_with_dna():
+def test_get_structure_from_nmr_with_dna() -> None:
     pdb_path = "tests/data/pdb/1A6B/1A6B.pdb"

     pdb = pdb2sql(pdb_path)
     try:
         structure = get_structure(pdb, "101M")
     finally:
-        pdb._close()  # noqa: SLF001 (private member accessed)
+        pdb._close()  # noqa: SLF001

     assert structure is not None
     assert structure.chains[0].residues[0].amino_acid is None  # DNA


-def test_residue_contact_pairs():
+def test_residue_contact_pairs() -> None:
     pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb"

     pdb = pdb2sql(pdb_path)
     try:
         structure = get_structure(pdb, "1ATN")
     finally:
-        pdb._close()  # noqa: SLF001 (private member accessed)
+        pdb._close()  # noqa: SLF001

     residue_pairs = get_residue_contact_pairs(pdb_path, structure, "A", "B", 8.5)

     assert len(residue_pairs) > 0


-def test_surrounding_residues():
+def test_surrounding_residues() -> None:
     pdb_path = "tests/data/pdb/101M/101M.pdb"

     pdb = pdb2sql(pdb_path)
     try:
         structure = get_structure(pdb, "101M")
     finally:
-        pdb._close()  # noqa: SLF001 (private member accessed)
+        pdb._close()  # noqa: SLF001

     all_residues = structure.get_chain("A").residues
     # A nicely centered residue
diff --git a/tests/utils/test_community_pooling.py b/tests/utils/test_community_pooling.py
index f362dab3e..3f49aee9c 100644
--- a/tests/utils/test_community_pooling.py
+++ b/tests/utils/test_community_pooling.py
@@ -12,24 +12,24 @@


 class TestCommunity(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]], dtype=torch.long)

         self.x = torch.tensor([[0], [1], [2], [3], [4], [5]], dtype=torch.float)
         self.data = Data(x=self.x, edge_index=self.edge_index)
-        self.data.pos = torch.tensor(np.random.rand(self.data.num_nodes, 3))  # noqa: NPY002 (legacy numpy code)
+        self.data.pos = torch.tensor(np.random.rand(self.data.num_nodes, 3))  # noqa: NPY002

-    def test_detection_mcl(self):
+    def test_detection_mcl(self) -> None:
         community_detection(self.data.edge_index, self.data.num_nodes, method="mcl")

-    def test_detection_louvain(self):
+    def test_detection_louvain(self) -> None:
         community_detection(self.data.edge_index, self.data.num_nodes, method="louvain")

     @unittest.expectedFailure
-    def test_detection_error(self):
+    def test_detection_error(self) -> None:
         community_detection(self.data.edge_index, self.data.num_nodes, method="xxx")

-    def test_detection_per_batch_mcl(self):
+    def test_detection_per_batch_mcl(self) -> None:
         Batch().from_data_list([self.data, self.data])
         community_detection_per_batch(
             self.data.edge_index,
@@ -38,7 +38,7 @@ def test_detection_per_batch_mcl(self):
             method="mcl",
         )

-    def test_detection_per_batch_louvain1(self):
+    def test_detection_per_batch_louvain1(self) -> None:
         Batch().from_data_list([self.data, self.data])
         community_detection_per_batch(
             self.data.edge_index,
@@ -48,7 +48,7 @@ def test_detection_per_batch_louvain1(self):
         )

     @unittest.expectedFailure
-    def test_detection_per_batch_louvain2(self):
+    def test_detection_per_batch_louvain2(self) -> None:
         Batch().from_data_list([self.data, self.data])
         community_detection_per_batch(
             self.data.edge_index,
@@ -57,7 +57,7 @@ def test_detection_per_batch_louvain2(self):
             method="xxxx",
         )

-    def test_pooling(self):
+    def test_pooling(self) -> None:
         batch = Batch().from_data_list([self.data, self.data])
         cluster = community_detection(batch.edge_index, batch.num_nodes)
diff --git a/tests/utils/test_earlystopping.py b/tests/utils/test_earlystopping.py
index 6c26198ea..24402b20e 100644
--- a/tests/utils/test_earlystopping.py
+++ b/tests/utils/test_earlystopping.py
@@ -4,7 +4,11 @@
 dummy_train_losses = [3, 2, 1, 2, 0.5, 2, 3, 4, 5, 1, 7]


-def base_earlystopper(patience=10, delta=0, maxgap=None):
+def base_earlystopper(
+    patience: int = 10,
+    delta: float = 0,
+    maxgap: float | None = None,
+) -> int:
     early_stopping = EarlyStopping(
         patience=patience,
         delta=delta,
@@ -14,7 +18,7 @@ def base_earlystopper(patience=10, delta=0, maxgap=None):

     for ep, loss in enumerate(dummy_val_losses):
         # check early stopping criteria
-        print(f"Epoch #{ep}", end=": ")
+        print(f"Epoch #{ep}", end=": ")  # noqa:T201
         early_stopping(ep, loss, dummy_train_losses[ep])
         if early_stopping.early_stop:
             break
@@ -22,14 +26,14 @@ def base_earlystopper(patience=10, delta=0, maxgap=None):
     return ep


-def test_patience():
+def test_patience() -> None:
     patience = 3
     final_ep = base_earlystopper(patience=patience)
     # should terminate at epoch 7
     assert final_ep == 7


-def test_patience_with_delta():
+def test_patience_with_delta() -> None:
     patience = 3
     delta = 1
     final_ep = base_earlystopper(patience=patience, delta=delta)
@@ -37,7 +41,7 @@ def test_patience_with_delta():
     assert final_ep == 5


-def test_maxgap():
+def test_maxgap() -> None:
     maxgap = 1
     final_ep = base_earlystopper(maxgap=maxgap)
     # should terminate at epoch 9
diff --git a/tests/utils/test_exporters.py b/tests/utils/test_exporters.py
index d49ffdd93..5e3243876 100644
--- a/tests/utils/test_exporters.py
+++ b/tests/utils/test_exporters.py
@@ -14,13 +14,13 @@


 class TestOutputExporters(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self._work_dir = mkdtemp()

-    def tearDown(self):
+    def tearDown(self) -> None:
         shutil.rmtree(self._work_dir)

-    def test_collection(self):
+    def test_collection(self) -> None:
         exporters = [
             TensorboardBinaryClassificationExporter(self._work_dir),
             HDF5OutputExporter(self._work_dir),
@@ -49,7 +49,7 @@ def test_collection(self):
         assert len(os.listdir(self._work_dir)) == 2  # tensorboard & table

     @patch("torch.utils.tensorboard.SummaryWriter.add_scalar")
-    def test_tensorboard_binary_classif(self, mock_add_scalar):
+    def test_tensorboard_binary_classif(self, mock_add_scalar) -> None:  # noqa: ANN001
         tensorboard_exporter = TensorboardBinaryClassificationExporter(self._work_dir)

         pass_name = "test"
@@ -60,7 +60,7 @@ def test_tensorboard_binary_classif(self, mock_add_scalar):
         targets = [0, 1, 1]
         loss = 0.1

-        def _check_scalar(name, scalar, timestep):  # noqa: ARG001 (unused argument)
+        def _check_scalar(name: str, scalar: float, timestep) -> None:  # noqa: ARG001, ANN001
             if name == f"{pass_name} cross entropy loss":
                 assert scalar < 1.0
             else:
@@ -72,7 +72,7 @@ def _check_scalar(name, scalar, timestep):  # noqa: ARG001 (unused argument)
         tensorboard_exporter.process(pass_name, epoch_number, entry_names, outputs, targets, loss)
         assert mock_add_scalar.called

-    def test_scatter_plot(self):
+    def test_scatter_plot(self) -> None:
         scatterplot_exporter = ScatterPlotExporter(self._work_dir)

         epoch_number = 0
@@ -98,7 +98,7 @@ def test_scatter_plot(self):

         assert os.path.isfile(scatterplot_exporter.get_filename(epoch_number))

-    def test_hdf5_output(self):
+    def test_hdf5_output(self) -> None:
         output_exporter = HDF5OutputExporter(self._work_dir)
         path_output_exporter = os.path.join(self._work_dir, "output_exporter.hdf5")
         entry_names = ["entry1", "entry2", "entry3"]
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
index 2e2c4bd05..b792ff29b 100644
--- a/tests/utils/test_graph.py
+++ b/tests/utils/test_graph.py
@@ -12,7 +12,7 @@
 from deeprank2.domain import edgestorage as Efeat
 from deeprank2.domain import gridstorage
 from deeprank2.domain import nodestorage as Nfeat
-from deeprank2.domain import targetstorage as Target
+from deeprank2.domain import targetstorage as targets
 from deeprank2.molstruct.pair import ResidueContact
 from deeprank2.utils.buildgraph import get_structure
 from deeprank2.utils.graph import Edge, Graph, Node
@@ -28,14 +28,14 @@


 @pytest.fixture()
-def graph():
+def graph() -> Graph:
     """Build a simple graph of two nodes and one edge in between them."""
     # load the structure
     pdb = pdb2sql("tests/data/pdb/101M/101M.pdb")
     try:
         structure = get_structure(pdb, entry_id)
     finally:
-        pdb._close()  # noqa: SLF001 (private member accessed)
+        pdb._close()  # noqa: SLF001

     # build a contact from two residues
     residue0 = structure.chains[0].residues[0]
@@ -69,7 +69,7 @@ def graph():
     return graph


-def test_graph_write_to_hdf5(graph):
+def test_graph_write_to_hdf5(graph: Graph) -> None:
     """Test that the graph is correctly written to hdf5 file."""
     # create a temporary hdf5 file to write to
     tmp_dir_path = tempfile.mkdtemp()
@@ -102,13 +102,13 @@ def test_graph_write_to_hdf5(graph):
             assert len(np.nonzero(edge_features_group[Efeat.INDEX][()])) > 0

             # target
-            assert grp[Target.VALUES][target_name][()] == target_value
+            assert grp[targets.VALUES][target_name][()] == target_value

     finally:
         shutil.rmtree(tmp_dir_path)  # clean up after the test


-def test_graph_write_as_grid_to_hdf5(graph):
+def test_graph_write_as_grid_to_hdf5(graph: Graph) -> None:
     """Test that the graph is correctly written to hdf5 file as a grid."""
     # create a temporary hdf5 file to write to
     tmp_dir_path = tempfile.mkdtemp()
@@ -146,13 +146,13 @@ def test_graph_write_as_grid_to_hdf5(graph):
                 assert np.all(data.shape == tuple(grid_settings.points_counts))

             # target
-            assert grp[Target.VALUES][target_name][()] == target_value
+            assert grp[targets.VALUES][target_name][()] == target_value

     finally:
         shutil.rmtree(tmp_dir_path)  # clean up after the test


-def test_graph_augmented_write_as_grid_to_hdf5(graph):
+def test_graph_augmented_write_as_grid_to_hdf5(graph: Graph) -> None:
     """Test that the graph is correctly written to hdf5 file as a grid."""
     # create a temporary hdf5 file to write to
     tmp_dir_path = tempfile.mkdtemp()
@@ -209,7 +209,7 @@ def test_graph_augmented_write_as_grid_to_hdf5(graph):
                 assert np.abs(np.sum(data) - np.sum(unaugmented_data)).item() < 0.2

             # target
-            assert grp[Target.VALUES][target_name][()] == target_value
+            assert grp[targets.VALUES][target_name][()] == target_value

     finally:
         shutil.rmtree(tmp_dir_path)  # clean up after the test
diff --git a/tests/utils/test_grid.py b/tests/utils/test_grid.py
index c1d53eec4..04192549d 100644
--- a/tests/utils/test_grid.py
+++ b/tests/utils/test_grid.py
@@ -5,7 +5,7 @@
 from deeprank2.utils.grid import Grid, GridSettings, MapMethod


-def test_grid_orientation():
+def test_grid_orientation() -> None:
     coord_error_margin = 1.0  # Angstrom
     points_counts = [10, 10, 10]
     grid_sizes = [30.0, 30.0, 30.0]
@@ -19,7 +19,7 @@ def test_grid_orientation():
         target_center = grid_points_group["center"][()]

     for resolution in VALID_RESOLUTIONS:
-        print(f"Testing for {resolution} level grids.")  # in case pytest fails, this will be printed.
+        print(f"Testing for {resolution} level grids.")  # noqa:T201; in case pytest fails, this will be printed.
         query = ProteinProteinInterfaceQuery(
             pdb_path="tests/data/pdb/1ak4/1ak4.pdb",
             resolution=resolution,
diff --git a/tests/utils/test_pssmdata.py b/tests/utils/test_pssmdata.py
index cc25e0b4c..45d7cf0a6 100644
--- a/tests/utils/test_pssmdata.py
+++ b/tests/utils/test_pssmdata.py
@@ -5,12 +5,12 @@
 from deeprank2.utils.parsing.pssm import parse_pssm


-def test_add_pssm():
+def test_add_pssm() -> None:
     pdb = pdb2sql("tests/data/pdb/1ATN/1ATN_1w.pdb")
     try:
         structure = get_structure(pdb, "1ATN")
     finally:
-        pdb._close()  # noqa: SLF001 (private member accessed)
+        pdb._close()  # noqa: SLF001

     for chain in structure.chains:
         with open(f"tests/data/pssm/1ATN/1ATN.{chain.id}.pdb.pssm", encoding="utf-8") as f: