Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve Trainer and DeeprankDataset logic for production testing #515

Merged
merged 63 commits into from
Jan 3, 2024

Conversation

gcroci2
Copy link
Collaborator

@gcroci2 gcroci2 commented Oct 19, 2023

Main changes:

  • Now DeeprankDataset takes as input train_data, that before was called dataset_train. Now train_data can be a DeeprankDataset representing the training set (as before), or a pre-trained model (new feature). It needs to be set only if train is False (as before, so only in validation/testing sets cases). Now we are able to use a test dataset without the need for the original model's training dataset. We can use the info stored in the pre-trained model to inherit the needed attributes.
    • There were many issues with loading the lambda transformations from the torch-saved pre-trained model (testing only use case). I ended up converting the lambdas to strings first, and then saving them. When the pre-trained model is loaded in the GraphDataset class, the strings representing the lambdas are evaluated and converted back to functions.

Secondary changes:

  • In DeeprankDataset classes, if target attribute is present (e.g., binary, inherited) but it's not in the HDF5 and we're not in the training phase, now no error is raised. Indeed, it should be possible to run a pre-trained model on data point/s even if the target value/s are not present, for doing predictions only. It's actually a typical test-case scenario, in which we don't have any labels for the new data points that we want to evaluate.
  • On the other hand, if we're in the training phase (train = True) and no target is set, or the set target is not in the hdf5 file/s, then a ValueError is raised.
  • self.pretrained_model_path is now defined in the init of the Trainer class and defaulted to pretrained_model.
  • self.model_load_state_dict is also defined in the init of the Trainer class and defaulted to None. It is assigned to a value only in the case of a pre-trained model or at the end of the training phase. This way in the test() method we can first verify if the model has actually been loaded (pre-trained case) or trained. If not, the test() method now throws an error.
  • In the Trainer class' init, before loading parameters and the pre-trained model there was a check for the target, which in the pre-trained model case could be not present at all (it is saved in the model itself, no need to define it in the Trainer instance). I removed the check.
  • I added the following attributes in _init_from_dataset, which need to be saved in the model's file for those cases in which we want to test it on some other data without redefining the training set: features_transform, means, devs, target_transform, classes, classes_to_index. I also added them to the model's state dict which is saved at the end of the training (_save_model, same for _load_params).
  • Now the saved model contains a key called data_type, needed for checking which type of dataset was used during the training of the model.
  • I removed the warning about not having a validation set during the training because it was given at each epoch. Now it's printed only once when you call the train() method.
  • When you call torch.load() is called on a model's file which contains GPU tensors, those tensors will be loaded to GPU by default. But if no Cuda was available, the code crashed. Now there is a check for that, and in case Cuda is not available the tensors are loaded into the CPU.
  • target_filter wasn't really working. Some other edits made me notice that the functionality was broken, and it's fixed now.

Still to solve:

  • For some reason, PyTorch gives a weird error about the Adam optimizer, but only in the 3.11 Python version. I tried to fix the relevant torch packages versions, but it's still failing. I haven't touched anything about the optimizer though. Any idea? @DaniBodor

@gcroci2 gcroci2 self-assigned this Oct 19, 2023
@gcroci2 gcroci2 linked an issue Oct 19, 2023 that may be closed by this pull request
5 tasks
@gcroci2 gcroci2 changed the title Improve Trainer logic for production testing refactor: improve Trainer and DeeprankDataset logic for production testing Oct 19, 2023
…te but no target values are present in the hdf5 file/s
@DaniBodor
Copy link
Collaborator

I unsubscribed to notifications for this PR for now. Please tag me again if needed and/or when you want me to re-review.

@gcroci2 gcroci2 requested a review from DaniBodor November 22, 2023 14:49
docs/getstarted.md Outdated Show resolved Hide resolved
deeprank2/dataset.py Outdated Show resolved Hide resolved
tests/test_dataset.py Outdated Show resolved Hide resolved
docs/getstarted.md Outdated Show resolved Hide resolved
Copy link
Collaborator

@DaniBodor DaniBodor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just leaving these as comments for now. Once you/we figure out why the build is failing, I will review that before approving.

Co-authored-by: Dani Bodor <d.bodor@esciencecenter.nl>
@DaniBodor
Copy link
Collaborator

It looks to me like the problem with the 3.11 build is really a core change in pytorch. I don't think it'll be easy for us (def not me) to figure out what the problem is. Maybe it's best to create an issue on pytorch and see if they know how to solve.

Copy link

This PR is stale because it has been open for 14 days with no activity.

@github-actions github-actions bot added the stale issue not touched from too much time label Dec 13, 2023
@gcroci2
Copy link
Collaborator Author

gcroci2 commented Jan 3, 2024

I am merging this PR. The issue with Python 3.11 will be solved in another PR.

@gcroci2 gcroci2 merged commit 226ff35 into dev Jan 3, 2024
5 of 7 checks passed
@gcroci2 gcroci2 deleted the 510_testing_pre_trained_gcroci2 branch January 3, 2024 18:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
stale issue not touched from too much time
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve Trainer and DeeprankDataset for production testing
2 participants