Conversation
Hello @justusschock! Thanks for updating this PR.
Comment last updated at 2021-02-22 12:13:55 UTC
Codecov Report

@@            Coverage Diff             @@
##            master     #130      +/-   ##
===========================================
- Coverage    87.39%    6.12%   -81.28%
===========================================
  Files           49       51        +2
  Lines         1579     1846      +267
===========================================
- Hits          1380      113     -1267
- Misses         199     1733     +1534

Flags with carried forward coverage won't be shown.
Continue to review full report at Codecov.
    @property
    def postprocessing_pipeline(self) -> PostProcessingPipeline:
        return self._get_pipeline('postprocessing')
TODO: Missing setter
"""Pipeline to use when there is no datamodule or it has not defined its pipeline""" | ||
return DataModule.default_pipeline() | ||
return DataModule.default_data_pipeline() |
TODO: also do this for postprocessing
@@ -188,3 +210,111 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

    def configure_finetune_callback(self):
        return []

    ### THE FOLLOWING IS A POC FOR DISTRIBUTED PREDICTION
    def on_predict_start(self):
@tchaton does it make sense to have a hook like that? (I think we need to revisit the Lightning hooks in general for all stages.)
Where is it called? I guess we could add a hook for predict. Needs a bit more exploration there.
Similar to on_fit_start, it would be called immediately after Trainer.predict is called.
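For illustration, a rough sketch of where such a hook could be invoked; MiniTrainer and its predict loop are placeholders, not Lightning's actual code:

```python
# Hypothetical sketch, not the actual Lightning implementation: where a
# predict hook analogous to on_fit_start could be invoked.
class MiniTrainer:
    def predict(self, model, dataloader):
        # call the hook right after Trainer.predict is entered, mirroring on_fit_start
        if hasattr(model, "on_predict_start"):
            model.on_predict_start()
        return [model.predict_step(batch, idx) for idx, batch in enumerate(dataloader)]
```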
        self.postprocessing_pipeline._attach_to_model(self)

    def predict_step(self, batch, batch_idx):
        # TODO: Move lightning predict loop from predict to predict_step
@tchaton You mentioned the prediction API is not final in Lightning, right?
IMO it makes sense to rename it to predict_step within the LightningModule, since (similar to training_step etc.) it only runs prediction for one batch at a time, making it more of a step (plus we can then use the predict keyword here independently :) )
Yes, I am good with that :)
That was my initial proposal, but it got turned down because users are expected to write model.predict(...) but not model.predict_step(), the same way nobody calls model.training_step directly.
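To make the naming question concrete, a minimal sketch of the split being discussed (Task here is a stand-in, not the actual Flash class): the user-facing predict loops over a dataloader and delegates per-batch work to predict_step, just as the Trainer loops over training_step.

```python
import torch

class Task(torch.nn.Module):
    """Stand-in for a Flash task; only illustrates the predict/predict_step split."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.backbone(x)

    def predict_step(self, batch, batch_idx):
        # per-batch prediction, analogous to training_step
        return self(batch)

    def predict(self, dataloader):
        # user-facing entry point: loop over batches and delegate to predict_step
        self.eval()
        with torch.no_grad():
            return [self.predict_step(batch, idx) for idx, batch in enumerate(dataloader)]
```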
    # TODO: Also use these for trainer creation in training?
    # TODO: Have default trainer kwargs per task?
    _trainer_kwargs = {}
    # TODO: Adjust this to trainer running stage from pl
Any thoughts on that @tchaton @aribornstein?
I like it. We had a similar function in a previous iteration of predict.
I think we had something like that for training in the beginning. The only downside I see is that it hides away the Lightning Trainer.
Maybe we could provide an optional argument for the user to pass a trainer in case they don't want to use the default one?
You mean another trainer class?
Actually, the trainer class is something I'd hardcode here, tbh.
It is one of the very fundamental Lightning aspects, and I feel that if users want to change it, they should either look into customization with callbacks/plugins or subclass the task and overwrite it here directly.
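To make the two options concrete, a hedged sketch (the argument names and defaults are assumptions, and it presumes a pytorch_lightning version where Trainer.predict exists):

```python
import pytorch_lightning as pl

# Hypothetical per-task defaults; not the values used in this PR.
_DEFAULT_TRAINER_KWARGS = {"logger": False}

def run_predict(model, dataloader, trainer=None, **trainer_kwargs):
    # Option A: build a Trainer from task defaults, overridable via kwargs.
    if trainer is None:
        trainer = pl.Trainer(**{**_DEFAULT_TRAINER_KWARGS, **trainer_kwargs})
    # Option B: the caller passed their own Trainer, so nothing is hidden from them.
    return trainer.predict(model, dataloaders=dataloader)
```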
flash/data/data_pipeline.py (outdated)

        return AutoDataset(data=data, load_fn=load_fn, load_per_sample=load_per_sample)

    def _generate_loader(
Should we make this public API? @tchaton
Yes, and it should be to_dataloader.
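A rough sketch of what the renamed public method could look like; the signature and the internal call to _generate_auto_dset are assumptions based on this thread, not the merged API:

```python
from typing import Any, Iterable, Union

from torch.utils.data import DataLoader

class DataPipelineSketch:
    """Illustration only; mirrors the rename discussed above."""

    def _generate_auto_dset(self, data):
        raise NotImplementedError  # built elsewhere in the PR

    def to_dataloader(self, data: Union[Iterable, Any], batch_size: int = 1, **loader_kwargs) -> DataLoader:
        # wrap the internally generated AutoDataset in a standard DataLoader
        return DataLoader(self._generate_auto_dset(data), batch_size=batch_size, **loader_kwargs)
```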
        )
        return model

    def _generate_auto_dset(self, data: Union[Iterable, Any]) -> AutoDataset:
Should we make this public API, too? @tchaton
I think to_dataloader is enough.
Looks great overall. I would have to convert some current pipelines to this new API and see how it feels.
        return self._get_pipeline('postprocessing')

    def _get_pipeline(self, pipeline_type: str):
        pipeline_attr_name = f'{pipeline_type}_pipline'
Typo in _pipline?
    @staticmethod
-   def default_pipeline() -> DataPipeline:
+   def default_data_pipeline() -> DataPipeline:
I think here we should take the data-type-specific default? For example, the collate for text isn't the same as for vision.
Yeah, but that's why each task would have its own default.
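A small sketch of that per-task default (the class names below are hypothetical; only default_pipeline/default_data_pipeline appear in the diff):

```python
class DataPipeline:
    @staticmethod
    def default_data_pipeline() -> "DataPipeline":
        # generic fallback used when neither task nor datamodule defines one
        return DataPipeline()

class TextDataPipeline(DataPipeline):
    """Hypothetical: would carry a tokenizer-based collate instead of the vision one."""

class TextClassifier:
    @staticmethod
    def default_data_pipeline() -> DataPipeline:
        # the task-level default wins, so text gets its own collate
        return TextDataPipeline()
```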
        if self.trainer is not None and hasattr(self.trainer, 'datamodule') and self.trainer.datamodule is not None:
            if hasattr(self.trainer.datamodule,
                       pipeline_attr_name) and getattr(self.trainer.datamodule, pipeline_attr_name is not None):
When can pipeline_attr_name be None?
It can't; that should be outside the brackets :)
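Spelled out, the intended condition with the closing bracket moved (a sketch of the fix, not a committed change):

```python
# `getattr` should receive only the attribute name; the `is not None`
# check belongs outside the call.
if hasattr(self.trainer.datamodule, pipeline_attr_name) \
        and getattr(self.trainer.datamodule, pipeline_attr_name) is not None:
    ...
```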
        elif post_collate_overriden:
            worker_collate = collate_fn
            device_collate = self._do_nothing_collate

        elif device_pre_collate_overriden:
            worker_collate = self._do_nothing_collate
            device_collate = collate_fn

        else:
            worker_collate = collate_fn
            device_collate = self._do_nothing_collate
Suggested change: drop the post_collate_overriden branch (it assigns the same values as the final else) and keep only:

        if device_pre_collate_overriden:
            worker_collate = self._do_nothing_collate
            device_collate = collate_fn

        else:
            worker_collate = collate_fn
            device_collate = self._do_nothing_collate
        was_seq = False

        for idx, loader in enumerate(dataloader):
            if isinstance(loader, DataLoader):
This won't work with custom dataloaders. See data_loading.py in Lightning.
That's why it's guarded like this. But IMO we shouldn't expect any custom loaders here, since then people would be using Lightning directly. Also, you cannot patch this for custom loaders without knowing their internals.
I don't agree that we shouldn't expect custom data loaders; I just ran into a huge issue with this today. If I want to extend Flash's capabilities, I shouldn't have to implement a Lightning datamodule from scratch to take advantage of Lightning's features. I should be able to extend Lightning as needed.
Yes, but what kind of interface do you want to assume in that case? E.g. if you have a custom loader class, there might not even be something we can attach to...
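One possible compromise, sketched only to make the trade-off concrete (nothing like this is in the PR): accept anything that exposes a collate_fn attribute and fail loudly otherwise.

```python
from torch.utils.data import DataLoader

def attach_collate(loader, collate_fn):
    # Hypothetical: attach where we know how to, instead of silently skipping custom loaders.
    if isinstance(loader, DataLoader) or hasattr(loader, "collate_fn"):
        loader.collate_fn = collate_fn
        return loader
    raise TypeError(
        f"Cannot attach a collate_fn to {type(loader).__name__}; "
        "expected a DataLoader or an object exposing a `collate_fn` attribute."
    )
```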
        setattr(model, loader_name, dataloader)

        model.transfer_batch_to_device = (
Love this! Pretty smart!
        if auto_collate:
            loader_kwargs['collate_fn'] = default_collate
        else:
            loader_kwargs['collate_fn'] = default_convert
Isn't default_convert used only for numpy arrays?
No, default_convert also converts numbers etc.; basically it does the same as default_collate without the tensor stacking.
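A quick illustration of the difference (the import path assumes a recent torch where both helpers are public under torch.utils.data; older releases expose them from torch.utils.data._utils.collate):

```python
import numpy as np
from torch.utils.data import default_collate, default_convert

samples = [np.ones(3), np.zeros(3)]
print(default_collate(samples))   # one stacked tensor of shape (2, 3)
print(default_convert(samples))   # list of two tensors of shape (3,), no stacking
```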
-       except AttributeError:
-           self._data_pipeline = self.default_pipeline()
-           return self._data_pipeline
+       return self._get_pipeline('data')

    @data_pipeline.setter
    def data_pipeline(self, data_pipeline: DataPipeline) -> None:
I find it a bit confusing to have DataPipeline and PostProcessingPipeline, as people might expect a PreprocessingPipeline. Worth iterating on this one.
Yes, I thought so as well. Basically I named it data_pipeline since it does loading + preprocessing, but I'm fine with changing it as well.
It also does postprocessing, with after_uncollate
@tchaton pointed me here. I must admit I don't feel that I have fully figured out the design behind the code. If I may make two observations despite my limited understanding, there are two things where I have the impression that the API you are creating here is not aligned with how I think about my data and my models and how they meet:
Part of this might not be solvable within Flash itself but might need amending Lightning (in particular, there seems to be no "auxiliary information" going into the train loop/step except the dataloader).
@t-vi Thanks for your comments, they are very valuable. Regarding your first point:
Regarding your second point: What kind of API would you expect as a user? Maybe we can look at how we can integrate this kind of API into Flash/Lightning :)
I must be missing lots of context here, because this goes in a completely different direction from what we initially discussed for Flash, right?
We agreed on using model.predict for simple inference and trainer.predict for distributed inference? Have there been further developments about this? What's the user-facing API with these changes?
Also, what is PostProcessingPipeline? Why is it separate from DataPipeline?
    def predict_step(self, batch, batch_idx):
        # TODO: Move lightning predict loop from predict to predict_step
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            x, y = batch
Yeah, this should be here in the future
What does this PR do?
This is just an API prototype; so far it is not completely working. It is basically just meant as a discussion starter :)
Fixes #67
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃