-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: improve Trainer
and DeeprankDataset
logic for production testing
#515
Conversation
Trainer
logic for production testingTrainer
and DeeprankDataset
logic for production testing
…in when train is False in dataset.py
…ns (much more reliable)
…rget values are present in the hdf5 file/s
…te but no target values are present in the hdf5 file/s
I unsubscribed to notifications for this PR for now. Please tag me again if needed and/or when you want me to re-review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just leaving these as comments for now. Once you/we figure out why the build is failing, I will review that before approving.
Co-authored-by: Dani Bodor <d.bodor@esciencecenter.nl>
It looks to me like the problem with the 3.11 build is really a core change in pytorch. I don't think it'll be easy for us (def not me) to figure out what the problem is. Maybe it's best to create an issue on pytorch and see if they know how to solve. |
This PR is stale because it has been open for 14 days with no activity. |
I am merging this PR. The issue with Python 3.11 will be solved in another PR. |
Main changes:
DeeprankDataset
takes as inputtrain_data
, that before was calleddataset_train
. Nowtrain_data
can be aDeeprankDataset
representing the training set (as before), or a pre-trained model (new feature). It needs to be set only iftrain
is False (as before, so only in validation/testing sets cases). Now we are able to use a test dataset without the need for the original model's training dataset. We can use the info stored in the pre-trained model to inherit the needed attributes.GraphDataset
class, the strings representing the lambdas are evaluated and converted back to functions.Secondary changes:
DeeprankDataset
classes, iftarget
attribute is present (e.g.,binary
, inherited) but it's not in the HDF5 and we're not in the training phase, now no error is raised. Indeed, it should be possible to run a pre-trained model on data point/s even if the target value/s are not present, for doing predictions only. It's actually a typical test-case scenario, in which we don't have any labels for the new data points that we want to evaluate.train = True
) and no target is set, or the set target is not in the hdf5 file/s, then aValueError
is raised.self.pretrained_model_path
is now defined in the init of theTrainer
class and defaulted topretrained_model
.self.model_load_state_dict
is also defined in the init of theTrainer
class and defaulted toNone
. It is assigned to a value only in the case of a pre-trained model or at the end of the training phase. This way in thetest()
method we can first verify if the model has actually been loaded (pre-trained case) or trained. If not, thetest()
method now throws an error.Trainer
instance). I removed the check._init_from_dataset
, which need to be saved in the model's file for those cases in which we want to test it on some other data without redefining the training set:features_transform
,means
,devs
,target_transform
,classes
,classes_to_index
. I also added them to the model'sstate
dict which is saved at the end of the training (_save_model
, same for_load_params
).data_type
, needed for checking which type of dataset was used during the training of the model.train()
method.target_filter
wasn't really working. Some other edits made me notice that the functionality was broken, and it's fixed now.Still to solve: