This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Datapipeline poc #130

Closed
wants to merge 7 commits into from

Conversation

@justusschock justusschock (Member) commented Feb 18, 2021

What does this PR do?

This is just some API prototype. So far it is not completely working. It is basically just meant as a discussion starter :)

Fixes #67

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@pep8speaks pep8speaks commented Feb 18, 2021

Hello @justusschock! Thanks for updating this PR.

Line 214:5: E266 too many leading '#' for block comment
Line 271:12: E713 test for membership should be 'not in'
Line 304:121: E501 line too long (124 > 120 characters)

Line 30:121: E501 line too long (138 > 120 characters)
Line 38:121: E501 line too long (132 > 120 characters)
Line 41:121: E501 line too long (205 > 120 characters)
Line 50:121: E501 line too long (205 > 120 characters)
Line 55:121: E501 line too long (140 > 120 characters)
Line 200:121: E501 line too long (160 > 120 characters)

Line 56:121: E501 line too long (140 > 120 characters)
Line 134:121: E501 line too long (174 > 120 characters)

Comment last updated at 2021-02-22 12:13:55 UTC

@codecov codecov bot commented Feb 18, 2021

Codecov Report

Merging #130 (31f65a5) into master (a6edeab) will decrease coverage by 81.27%.
The diff coverage is 9.67%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #130       +/-   ##
==========================================
- Coverage   87.39%   6.12%   -81.28%     
==========================================
  Files          49      51        +2     
  Lines        1579    1846      +267     
==========================================
- Hits         1380     113     -1267     
- Misses        199    1733     +1534     
Flag Coverage Δ
unittests 6.12% <9.67%> (-81.28%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
flash/core/model.py 6.36% <3.94%> (-89.33%) ⬇️
flash/data/postprocessing_pipeline.py 6.41% <6.41%> (ø)
flash/data/data_pipeline.py 15.20% <15.20%> (ø)
flash/text/__init__.py 0.00% <0.00%> (-100.00%) ⬇️
flash/vision/__init__.py 0.00% <0.00%> (-100.00%) ⬇️
flash/text/seq2seq/__init__.py 0.00% <0.00%> (-100.00%) ⬇️
flash/vision/detection/__init__.py 0.00% <0.00%> (-100.00%) ⬇️
flash/vision/embedding/__init__.py 0.00% <0.00%> (-100.00%) ⬇️
flash/text/seq2seq/core/__init__.py 0.00% <0.00%> (-100.00%) ⬇️
flash/vision/classification/model.py 0.00% <0.00%> (-100.00%) ⬇️
... and 39 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a6edeab...17cecb8.

@property
def postprocessing_pipeline(self) -> PostProcessingPipeline:
return self._get_pipeline('postprocessing')

Member Author:

TODO: Missing setter
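
A minimal sketch of what the missing setter could look like (the backing attribute name here is an assumption, not part of this diff):

    # Sketch only; the _postprocessing_pipeline attribute name is an assumption.
    @postprocessing_pipeline.setter
    def postprocessing_pipeline(self, postprocessing_pipeline: PostProcessingPipeline) -> None:
        self._postprocessing_pipeline = postprocessing_pipeline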

"""Pipeline to use when there is no datamodule or it has not defined its pipeline"""
return DataModule.default_pipeline()
return DataModule.default_data_pipeline()
Member Author:

TODO: also do this for postprocessing

@@ -188,3 +210,111 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

def configure_finetune_callback(self):
return []

### THE FOLLOWING IS A POC FOR DISTRIBUTED PREDICTION
def on_predict_start(self):
Member Author:

@tchaton does it make sense to have a hook like this? (I think we need to revisit the Lightning hooks in general for all stages.)

Contributor:

Where is it called? I guess we could add a hook for predict. Needs a bit more exploration there.

Member Author:

Similar to on_fit_start, it would be called immediately after Trainer.predict is called.
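
A rough sketch of the intended ordering (pseudo-flow only, not actual Lightning code):

    # Pseudo-flow sketch of where the proposed hook would fire relative to Trainer.predict.
    def _predict_flow(model, dataloader):
        model.on_predict_start()                  # e.g. attach the postprocessing pipeline
        for batch_idx, batch in enumerate(dataloader):
            model.predict_step(batch, batch_idx)
        # results would be gathered/uncollated afterwards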

self.postprocessing_pipeline._attach_to_model(self)

def predict_step(self, batch, batch_idx):
# TODO: Move lightning predict loop from predict to predict_step
Member Author:

@tchaton You mentioned the prediction API is not final in lightning, right?

IMO it makes sense to rename it to predict_step within the LightningModule, since (similar to training_step etc.) it only runs prediction for one batch at a time, making it more of a step (plus we can use the predict keyword here independently :) )

Contributor:

Yes, I am good with that :)

Contributor:

That was my initial proposal, but it got turned down because users are expected to write

model.predict(...)

but not

model.predict_step()

because nobody calls

model.training_step
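
For context, the distinction being debated, as an illustrative sketch (not actual Flash/Lightning code):

    # Illustrative only: users call the high-level entry point ...
    predictions = model.predict(samples)
    # ... while a *_step method is only ever invoked per batch by the loop itself,
    # the same way nobody calls model.training_step(batch, batch_idx) by hand.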

Comment on lines +264 to +267
# TODO: Also use these for trainer creation in training?
# TODO: Have default trainer kwargs per task?
_trainer_kwargs = {}
# TODO: Adjust this to trainer running stage from pl
Member Author:

Any thoughts on that, @tchaton @aribornstein?

Contributor:

I like it. We had a similar function in a previous iteration of predict.

Member Author:

I know we had something like that for training in the beginning. The only downside I see is that it hides away the Lightning Trainer.

Contributor:

Maybe we could provide an optional argument for the user to pass a trainer in case they don't want to use the default one?

Member Author:

You mean another trainer class?
Actually, the trainer class is something I'd hardcode here, tbh.
This is one of the very fundamental Lightning aspects, and I feel that if users want to change it, they should either look into customization with callbacks/plugins or subclass the task to override it here directly.
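
A minimal sketch of that customization path (the base class and the _create_trainer hook are hypothetical, purely for illustration):

    # Hypothetical sketch: customize trainer creation by subclassing the task;
    # SomeFlashTask and _create_trainer are assumptions, not Flash API.
    from pytorch_lightning import Trainer

    class MyTask(SomeFlashTask):
        def _create_trainer(self, **trainer_kwargs):
            trainer_kwargs.setdefault("precision", 16)
            return Trainer(**trainer_kwargs)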


return AutoDataset(data=data, load_fn=load_fn, load_per_sample=load_per_sample)

def _generate_loader(
Member Author:

Should we make this public API? @tchaton

Contributor:

Yes, and it should be to_dataloader.
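
Hypothetical usage once it is public (the exact signature is an assumption):

    # Hypothetical usage sketch; the signature is an assumption.
    loader = data_pipeline.to_dataloader(filenames, batch_size=32, num_workers=0)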

)
return model

def _generate_auto_dset(self, data: Union[Iterable, Any]) -> AutoDataset:
Member Author:

Should we make this public API? @tchaton

Contributor:

I think to_dataloader is enough.

@tchaton tchaton (Contributor) left a comment:

Looks great overall. I would have to convert some of the current pipelines to this new API and see how it feels.

return self._get_pipeline('postprocessing')

def _get_pipeline(self, pipeline_type: str):
pipeline_attr_name = f'{pipeline_type}_pipline'
Contributor:

_pipline typo?

@staticmethod
def default_pipeline() -> DataPipeline:
def default_data_pipeline() -> DataPipeline:
Contributor:

I think here we should take the data-type-specific default? For example, the collate for text isn't the same as for vision.

Member Author:

Yeah, but that's why each task would have its own default.


if self.trainer is not None and hasattr(self.trainer, 'datamodule') and self.trainer.datamodule is not None:
if hasattr(self.trainer.datamodule,
pipeline_attr_name) and getattr(self.trainer.datamodule, pipeline_attr_name is not None):
Contributor:

When can pipeline_attr_name be None?

Member Author:

It can't; that should be outside the brackets :)
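
i.e. the corrected check would presumably read:

    if hasattr(self.trainer.datamodule, pipeline_attr_name) \
            and getattr(self.trainer.datamodule, pipeline_attr_name) is not None: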

Comment on lines +82 to +92
        elif post_collate_overriden:
            worker_collate = collate_fn
            device_collate = self._do_nothing_collate

        elif device_pre_collate_overriden:
            worker_collate = self._do_nothing_collate
            device_collate = collate_fn

        else:
            worker_collate = collate_fn
            device_collate = self._do_nothing_collate
Contributor:

Suggested change

    # current
    elif post_collate_overriden:
        worker_collate = collate_fn
        device_collate = self._do_nothing_collate
    elif device_pre_collate_overriden:
        worker_collate = self._do_nothing_collate
        device_collate = collate_fn
    else:
        worker_collate = collate_fn
        device_collate = self._do_nothing_collate

    # suggested
    if device_pre_collate_overriden:
        worker_collate = self._do_nothing_collate
        device_collate = collate_fn
    else:
        worker_collate = collate_fn
        device_collate = self._do_nothing_collate

was_seq = False

for idx, loader in enumerate(dataloader):
if isinstance(loader, DataLoader):
Contributor:

This won't work with custom dataloaders. See data_loading.py in Lightning.

Member Author:

That's why it's guarded like this. But IMO we shouldn't expect any custom loaders here, since then people would just be using Lightning. Also, you cannot patch this for custom loaders without knowing their internals.

@aribornstein aribornstein (Contributor) commented Feb 20, 2021:

I don't agree that we shouldn't expect custom data loaders; I just ran into a huge issue with this today. If I want to extend Flash's capabilities, I shouldn't have to implement a Lightning datamodule from scratch to take advantage of Lightning's features. I should be able to extend Lightning as needed.

Member Author:

Yes, but what kind of interface do you want to assume in that case? E.g. if you have a custom loader class, there might not even be something we can attach to...


setattr(model, loader_name, dataloader)

model.transfer_batch_to_device = (
Contributor:

Love this! Pretty smart!

if auto_collate:
loader_kwargs['collate_fn'] = default_collate
else:
loader_kwargs['collate_fn'] = default_convert
Contributor:

Isn't default_convert used only for numpy arrays?

Member Author:

No, default_convert also converts numbers etc.; it basically does the same as default_collate, just without the tensor stacking.
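
A quick illustration of the difference (minimal sketch; recent PyTorch versions expose both helpers under torch.utils.data, and the expected outputs are shown as comments):

    import numpy as np
    from torch.utils.data import default_collate, default_convert

    batch = [np.array([0.0, 1.0]), np.array([2.0, 3.0])]
    default_convert(batch)  # [tensor([0., 1.]), tensor([2., 3.])] -- elements converted, not stacked
    default_collate(batch)  # tensor([[0., 1.], [2., 3.]])         -- converted and stacked into one tensor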

except AttributeError:
self._data_pipeline = self.default_pipeline()
return self._data_pipeline
return self._get_pipeline('data')

@data_pipeline.setter
def data_pipeline(self, data_pipeline: DataPipeline) -> None:
Contributor:

I find it a bit confusing to have DataPipeline and PostProcessingPipeline, as people might expect a PreprocessingPipeline. Worth iterating on this one.

Member Author:

Yes, I thought so as well. Basically I named it data_pipeline since it does loading + preprocessing. But I'm fine with changing it as well.

Contributor:

It also does postprocessing, with after_uncollate

@t-vi t-vi (Contributor) commented Feb 20, 2021

@tchaton pointed me here. I must admit I don't feel that I've figured out the design behind the code.

If I may make two observations despite my limited understanding: there are two places where I have the impression that the API you are creating here is not aligned with how I think about my data and my models and how they meet:

  • I don't think all preprocessing should be tied to the dataloader (/collate fn). To me, "same datamodule, different augmentation" or "I have an image from somewhere (e.g. the webcam) but want some part of the preprocessing" can happen, and it seems unnatural to have to modify the datamodule for this or to apply parts manually.
  • My impression is that the approach here for dealing with what is passed in to new_predict etc. is to set self.something and then use self as state. As a user this isn't what I'd expect (I would expect the state to be fixed, and that I might override parts of it through the args for this one call).

Part of this might not be solvable within flash itself but might need amending lightning (in particular, there seems to be no "auxiliary information" going into the train loop/step except the dataloader).

@justusschock justusschock (Member Author):

@t-vi Thanks for your comments, they are very valuable.

Regarding your first point:

  • Why not tie everything to the loader? Creating a loader (especially one with num_workers=0) is almost no overhead.
  • In terms of augmentations: you can, for example, have them as an init argument for your pipeline (see the sketch after this list).
  • To the datamodule part: You're right and the integration with the datamodule is definitely not perfect. This should just be a base to iterate on :)
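
A minimal sketch of the init-argument idea from the first bullet (the pipeline class and argument names are hypothetical, not from this PR):

    # Hypothetical sketch: augmentations handed to the pipeline at construction time.
    from torchvision import transforms

    train_pipeline = MyImageDataPipeline(
        augmentations=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    )
    predict_pipeline = MyImageDataPipeline(augmentations=transforms.ToTensor())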

Regarding your second point:
You're right to some extent (and the part where you may not be right is not visible here, so let me explain my thoughts behind that). Yes, I kind of wanted to use the model as state, but only temporarily. The part that's still missing here (but should definitely come) is reverting back to the original state.

What kind of API would you expect as a user? Maybe we can look at how we can integrate this kind of API into flash/lightning :)

@carmocca carmocca (Contributor) left a comment:

I must be missing lots of context here, because this goes in a completely different direction from what we initially discussed for Flash, right?

We agreed on using model.predict for simple inference and trainer.predict for distributed inference, right? Have there been further developments on this? What's the user-facing API with these changes?

Also, what is PostProcessingPipeline? Why is it separate from DataPipeline?

def predict_step(self, batch, batch_idx):
# TODO: Move lightning predict loop from predict to predict_step
if isinstance(batch, (tuple, list)) and len(batch) == 2:
x, y = batch

@carmocca carmocca mentioned this pull request Feb 24, 2021
@justusschock justusschock deleted the datapipeline_poc branch March 10, 2021 15:23
@Borda Borda mentioned this pull request Mar 10, 2021
Development

Successfully merging this pull request may close these issues.

Improve DataPipeline API
6 participants