Workers in the distributed scenario need to see different instances #4241
Conversation
Flake still sucks though. It doesn't like assigning lambdas.
There are no tests and I'm not done testing it manually yet either. That said, @matt-gardner, thoughts?
The basic approach seems fine to me, other than the use of lambda functions.
```python
read_fn = self._read
if self.max_instances is not None:
    # Double lambda ensures that read_fn doesn't call itself recursively.
    read_fn = lambda f: (
```
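The "double lambda" guards against Python's late-binding closures. A minimal standalone sketch of the pitfall, using a toy read function rather than the actual reader:

```python
# Toy illustration: closures capture the *variable*, not its value, so
# rebinding a name to a lambda that closes over that same name recurses.
read = lambda: ["a", "b", "c"]

# Naive rewrap would recurse forever, because inside the new lambda
# `read` refers to the new lambda itself:
#   read = lambda: read()[:2]

# The "double lambda" binds the *current* value of `read` to a parameter,
# so the inner function calls the original:
read = (lambda inner: lambda: inner()[:2])(read)

print(read())  # prints ['a', 'b']
```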
You can't use lambda functions with pickle, so this doesn't look like it'll work in the distributed case, which shares objects via pickling. You have to make this a class method if you want to do this redirection. That should also make it much easier to understand, because you can just call `self._read` in that function, instead of passing functions around.
Same comment on the one below. You just need to add two class functions instead of this, and set `read_fn = self._read` in the default case and `read_fn = self._something_else` in the other cases. Though you might want to just use one function for the other cases, having it check `self.max_instances` on its own; the alternative again requires passing curried functions or lambdas around, which won't work with pickle.
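The pickling constraint is easy to demonstrate in isolation. A small sketch with an illustrative `Reader` class (not the actual dataset reader):

```python
import pickle

class Reader:
    """Illustrative stand-in for a dataset reader."""
    def _read(self, path):
        return f"reading {path}"

reader = Reader()

# A bound method pickles fine: pickle records the instance plus the
# method's name, and looks the method up again on load.
restored = pickle.loads(pickle.dumps(reader._read))
print(restored("train.txt"))  # prints: reading train.txt

# A lambda cannot be pickled, which is why wrapping self._read in a
# lambda breaks once the object has to cross a process boundary.
wrapped = lambda path: reader._read(path)
try:
    pickle.dumps(wrapped)
except (pickle.PicklingError, AttributeError) as err:
    print(f"lambda is not picklable: {type(err).__name__}")
```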
This does work. They don't need to be pickled as they are instantiated in the workers. But I'm not sure this is the most readable way of wrapping these functions.
Are you certain? Have you tried lazy + distributed together? Passing a lambda function to `LazyInstances` below was precisely the cause of a problem that we fixed recently: #4026.
Actually, I think you are right. This only worked in my experiments because I wasn't trying a lazy dataset.
50 lines of code later, this now also works with lazy datasets.
```diff
@@ -209,7 +230,12 @@ def read(self, file_path: str) -> Dataset:
         )

         # And finally we write to the cache if we need to.
-        if cache_file and not os.path.exists(cache_file):
+        if (
+            self.max_instances is None
```
This line made me pause and wonder if it was correct. I think I see why you did this (if you've set that, you're probably testing something, and don't want to cache only part of the data as if it's the whole thing), but it's not obvious at first glance, so a comment explaining why it's there would be nice.
An alternative to having this check here would be to move the caching logic above the place where you only keep `max_instances` instances (or move the slicing below this). I'd probably vote for that option instead of this. Well, that then might defeat the point of `max_instances`, because you'd be reading all of the data... I guess it depends on which flag you think takes precedence. There's a fair argument that `max_instances` should take precedence, and this check should stay as it is. In that case, noting in the docstring that `max_instances` disables saving data to the cache would be good.
I added this because I didn't want to write 500 instances to the cache, and then treat them as the whole dataset when reading the cache. I think the old code would do that.
There are two ways to avoid that. We could make the number of instances part of the filename of the cache, or we could never cache when `max_instances` is set. The latter seemed easier. It takes almost no time to read 500 instances anyway.
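The precedence being settled on here (a set `max_instances` disables writing the cache) can be sketched as a small predicate. The function name below is hypothetical, not the actual AllenNLP code:

```python
import os

def should_write_cache(cache_file, max_instances):
    """Hypothetical sketch of the guard discussed above: only write the
    cache when we read the *full* dataset, so a truncated run can never
    be mistaken for the whole dataset by a later cache read."""
    return (
        cache_file is not None
        and not os.path.exists(cache_file)
        and max_instances is None
    )

print(should_write_cache("/nonexistent/cache.jsonl", None))  # True: full read, no cache yet
print(should_write_cache("/nonexistent/cache.jsonl", 500))   # False: truncated, don't cache
print(should_write_cache(None, None))                        # False: caching disabled
```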
Yes, I agree, I ended up at the same place at the end of my comment. We should just add a comment and update the docstring to make this clear.
Done
LGTM when you think it's ready.
@matt-gardner, you already approved this, but I added quite a bit more. Do you want to give it another look?
Your solution with classes seems better to me than the extra functions I was suggesting. LGTM, with just a few minor comments.
```python
    In the case that you have an IterableDataset and you call len, the pytorch dataloader
    actually spits out a warning - but we need actually calling it to not crash.
    """
    return 1
```
I think we decided in another thread that this should crash, instead of returning 1, didn't we? But you're just moving this logic, not changing it, so if you want to leave that for a separate PR, that's fine with me.
Yes, that's what I was thinking. There is enough going on in this one as it is.
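The behavior being deferred here can be sketched without torch; `LazyInstances` below is a simplified stand-in for the real class, not the actual implementation:

```python
class LazyInstances:
    """Simplified stand-in for a lazy, iterable-style dataset wrapper."""

    def __init__(self, instance_generator):
        self.instance_generator = instance_generator

    def __iter__(self):
        return iter(self.instance_generator())

    def __len__(self):
        # PyTorch's DataLoader warns when len() is called on an
        # IterableDataset; returning a constant keeps len() from
        # crashing, at the cost of reporting a bogus size.
        return 1

dataset = LazyInstances(lambda: ["inst1", "inst2", "inst3"])
print(len(dataset))   # prints 1, regardless of the true size
print(list(dataset))  # the real instances still stream through
```

The thread above argues that crashing would be more honest than the bogus `1`; this sketch just shows the status quo being moved around.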
```diff
@@ -122,22 +122,29 @@ def __call__(
         epoch: int,
         batch_number: int,
         is_training: bool,
+        is_master: bool,
```
You don't need this, do you? It's already queryable from the `trainer` argument. You also didn't add it to the `EpochCallback`.
```diff
             epoch,
             batches_this_epoch,
             is_training=True,
+            is_master=self._master,
```
If the issue is the opaque private argument, I'd vote for adding a simple `is_master()` method to the trainer, instead of passing yet another flag here.
I'll be easily swayed one way or another in this matter, but my reasoning for having it this way was this: when you implement a `BatchCallback`, it's easy to forget about the multi-process case, but you almost certainly need to think about it. By making it a parameter, it becomes more visible and harder to ignore.
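The argument can be made concrete with a minimal sketch: with `is_master` in the signature, every implementer has to accept it. The classes below are simplified stand-ins for the real API:

```python
class BatchCallback:
    """Simplified stand-in for the trainer's batch callback interface."""

    def __call__(self, trainer, epoch, batch_number, is_training, is_master):
        pass

class CollectFromMaster(BatchCallback):
    """Example implementation that only records output in the master process,
    avoiding N duplicate lines under distributed training."""

    def __init__(self):
        self.lines = []

    def __call__(self, trainer, epoch, batch_number, is_training, is_master):
        # Having to accept is_master makes the multi-process case
        # hard to ignore when writing a callback.
        if is_master:
            self.lines.append(f"epoch {epoch}, batch {batch_number}")

callback = CollectFromMaster()
callback(None, 0, 1, True, is_master=True)   # master: recorded
callback(None, 0, 1, True, is_master=False)  # worker: ignored
print(callback.lines)  # prints ['epoch 0, batch 1']
```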
That's a good point that I hadn't thought of. Given that, I could also go either way here. Whatever you think is best. Is the epoch callback only ever called from master?
Yes, the epoch callback is only called from master. I thought that would make more sense, but on reflection, I'm not sure it's true. I'll make another PR that adds the same thing to the epoch callback.
Actually, that wasn't even true. The epoch callback was called all the time. Glad I checked!
```python
    pass


BatchCallback.register("null")(BatchCallback)
```
Just FYI, if that check in `FromParams` that we added this for bothers us in the future, we should probably just remove it. It's potentially brittle, and this is an easier solution.
Maybe, but in this case it's easy to write a `null` implementation. Not all classes are like that. But I guess we can always throw `NotImplementedError()`.
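A toy version of the pattern being discussed, registering the base class as its own "null" implementation; `register`/`by_name` here are simplified stand-ins for AllenNLP's `Registrable` machinery:

```python
class BatchCallback:
    """Base class that doubles as the do-nothing 'null' implementation."""
    _registry = {}

    @classmethod
    def register(cls, name):
        def add_to_registry(subclass):
            cls._registry[name] = subclass
            return subclass
        return add_to_registry

    @classmethod
    def by_name(cls, name):
        return cls._registry[name]

    def __call__(self, *args, **kwargs):
        pass  # null implementation: do nothing

# Registering the base class itself under "null" means a default,
# do-nothing callback can be constructed by name from configuration.
BatchCallback.register("null")(BatchCallback)

null_callback = BatchCallback.by_name("null")()
print(null_callback(batch=1))  # prints None: the call is a no-op
```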
tests/commands/train_test.py
```python
for metadata in batch["metadata"]:
    logger.info(f"First word from training data: '{metadata['words'][0]}'")


def in_worker(self, *args, **kwargs) -> None:
```
What's this doing? I don't see it called anywhere.
Sorry, that was a leftover from an earlier iteration. I removed it.