This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Workers in the distributed scenario need to see different instances #4241

Merged · 37 commits · May 21, 2020

Conversation

@dirkgr (Member) commented May 15, 2020

No description provided.

@dirkgr dirkgr requested a review from matt-gardner May 15, 2020 01:19
@dirkgr (Member Author) commented May 15, 2020

There are no tests and I'm not done testing it manually yet either. That said, @matt-gardner, thoughts?
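
(For context on the PR title: one common way to make distributed workers see different instances is round-robin sharding by rank. The sketch below only illustrates that general idea, not the code in this PR; the function name and signature are invented for the example.)

import itertools
from typing import Iterable, Iterator


def shard_instances(instances: Iterable, rank: int, world_size: int) -> Iterator:
    # Worker `rank` keeps items rank, rank + world_size, rank + 2 * world_size, ...
    return itertools.islice(instances, rank, None, world_size)


# With two workers, worker 0 sees [0, 2, 4] and worker 1 sees [1, 3, 5],
# so together they cover the dataset exactly once.
assert list(shard_instances(range(6), rank=1, world_size=2)) == [1, 3, 5]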

@matt-gardner (Contributor) left a comment

The basic approach seems fine to me, other than the use of lambda functions.

        read_fn = self._read
        if self.max_instances is not None:
            # Double lambda ensures that read_fn doesn't call itself recursively.
            read_fn = lambda f: (

matt-gardner (Contributor):

You can't use lambda functions with pickle, so this doesn't look like it'll work in the distributed case, which shares objects via pickling. You have to make this a class method if you want to do this redirection. That should also make it much easier to understand, because you can just call self._read in that function instead of passing functions around.

Same comment on the one below. You just need to add two class functions instead of this, and set read_fn = self._read in the default case, and read_fn = self._something_else in the other cases. Though you might want to just use one function for the other cases, having it check self.max_instances on its own; the alternative again requires passing curried functions or lambdas around, which won't work with pickle.
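
(A minimal illustration of the pickling constraint being discussed, using a toy class rather than the real DatasetReader: a bound method survives pickling because pickle stores the instance plus the method name, while a lambda wrapping the same call does not.)

import pickle


class ToyReader:
    """Toy stand-in for a DatasetReader; not the actual AllenNLP class."""

    def _read(self, file_path):
        return [file_path]


reader = ToyReader()

# A bound method pickles fine: pickle records the instance and the method name.
pickle.dumps(reader._read)

# A lambda wrapping the same call cannot be pickled, which breaks any code path
# that ships the read function to another process.
try:
    pickle.dumps(lambda f: reader._read(f))
except Exception as e:
    print(f"lambda is not picklable: {e}")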

dirkgr (Member Author):

This does work. They don't need to be pickled as they are instantiated in the workers. But I'm not sure this is the most readable way of wrapping these functions.

matt-gardner (Contributor):

Are you certain? Have you tried lazy + distributed together? Passing a lambda function to LazyInstances below was precisely the cause of a problem that we fixed recently: #4026.

dirkgr (Member Author):

Actually, I think you are right. This only worked in my experiments because I wasn't trying a lazy dataset.

dirkgr (Member Author):

50 lines of code later, this now also works with lazy datasets.

@@ -209,7 +230,12 @@ def read(self, file_path: str) -> Dataset:
         )

         # And finally we write to the cache if we need to.
-        if cache_file and not os.path.exists(cache_file):
+        if (
+            self.max_instances is None

matt-gardner (Contributor):

This line made me pause and wonder if it was correct. I think I see why you did this (if you've set that, you're probably testing something, and don't want to cache only a part of the data as if it's the whole thing), but it's not obvious at first glance, so a comment explaining why it's there would be nice.

An alternative to having this check here would be to move the caching logic to above the place where you only keep max_instances (or move the slicing to below this). I'd probably vote for that option, instead of this. Well, that then might defeat the point of the max_instances, because you'd be reading all of the data... I guess it depends on which flag you think takes precedence. There's a fair argument that max_instances should take precedence, and this check should stay as it is. In that case, noting this in the docstring (that max_instances disables saving data to the cache) would be good.

dirkgr (Member Author):

I added this because I didn't want to write 500 instances to the cache, and then treat them as the whole dataset when reading the cache. I think the old code would do that.

There are two ways to avoid that. We could make the number of instances part of the filename of the cache, or we could never cache when max_instances is set. The latter seemed easier. It takes almost no time to read 500 instances anyways.
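
(A standalone sketch of the guard described above; write_to_cache is a hypothetical callable standing in for the reader's cache-writing code, not a real AllenNLP method.)

import os
from typing import Callable, Iterable, Optional


def maybe_write_cache(
    instances: Iterable,
    cache_file: Optional[str],
    max_instances: Optional[int],
    write_to_cache: Callable[[str, Iterable], None],
) -> None:
    # Only write the cache when the full dataset was read, so a run truncated
    # by max_instances is never mistaken for the whole dataset later.
    if (
        cache_file is not None
        and max_instances is None  # max_instances disables caching entirely
        and not os.path.exists(cache_file)
    ):
        write_to_cache(cache_file, instances)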

matt-gardner (Contributor):

Yes, I agree, I ended up at the same place at the end of my comment. We should just add a comment and update the docstring to make this clear.

dirkgr (Member Author):

Done

@matt-gardner (Contributor) left a comment

LGTM when you think it's ready.

@dirkgr dirkgr marked this pull request as ready for review May 21, 2020 03:46
@dirkgr (Member Author) commented May 21, 2020

@matt-gardner, you already approved this, but I added quite a bit more. Do you want to give it another look?

@matt-gardner (Contributor) left a comment

Your solution with classes seems better to me than the extra functions I was suggesting. LGTM, with just a few minor comments.

        In the case that you have an IterableDataset and you call len, the pytorch dataloader
        actually spits out a warning - but calling it still needs to not crash.
        """
        return 1

matt-gardner (Contributor):

I think we decided in another thread that this should crash, instead of returning 1, didn't we? But you're just moving this logic, not changing it, so if you want to leave that for a separate PR, that's fine with me.

dirkgr (Member Author):

Yes, that's what I was thinking. There is enough going on in this one as it is.
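
(For reference, the "crash instead of returning 1" alternative deferred to a later PR could look roughly like this; the class name is invented for the sketch. A plain torch IterableDataset behaves the same way simply by not defining __len__ at all.)

class LazyDatasetSketch:
    def __len__(self) -> int:
        # Mirror calling len() on a plain torch IterableDataset (which defines
        # no __len__): force the caller to handle the missing length.
        raise TypeError("This dataset is lazy and does not have a defined length.")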

@@ -122,22 +122,29 @@ def __call__(
         epoch: int,
         batch_number: int,
         is_training: bool,
+        is_master: bool,

matt-gardner (Contributor):

You don't need this, do you? It's already queryable from the trainer argument. You also didn't add it to the EpochCallback.

            epoch,
            batches_this_epoch,
            is_training=True,
            is_master=self._master,

matt-gardner (Contributor):

If the issue is the opaque private argument, I'd vote for adding a simple is_master() method to the trainer, instead of passing yet another flag here.

dirkgr (Member Author):

I'll be easily swayed one way or another in this matter, but my reasoning for having it this way was this: When you implement a BatchCallback, it's easy to forget about the multi-process case, but you almost certainly need to think about it. By making it a parameter, it becomes more visible and harder to ignore.
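
(A hedged sketch of the kind of callback this argues for; the class name is invented, and the signature only mirrors the arguments visible in the diff above rather than the full API. With is_master passed in explicitly, the master-only guard is one obvious line.)

import logging
from typing import Any

logger = logging.getLogger(__name__)


class LogOnMasterCallback:
    def __call__(
        self,
        trainer: Any,
        epoch: int,
        batch_number: int,
        is_training: bool,
        is_master: bool,
    ) -> None:
        # The explicit flag makes it hard to forget that this runs in every
        # distributed worker; master-only work is guarded in one line.
        if not is_master:
            return
        logger.info("Epoch %d, batch %d (training=%s)", epoch, batch_number, is_training)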

matt-gardner (Contributor):

That's a good point that I hadn't thought of. Given that, I could also go either way here. Whatever you think is best. Is the epoch callback only ever called from master?

dirkgr (Member Author):

Yes, the epoch callback is only called from master. I thought it would make more sense, but on reflection, I'm not sure that's true. I'll make another PR that adds the same thing to the epoch callback.

dirkgr (Member Author):

Actually, that wasn't even true. The epoch callback was called all the time. Glad I checked!

        pass


BatchCallback.register("null")(BatchCallback)

matt-gardner (Contributor):

Just FYI, if that check in FromParams that we added this for bothers us in the future, we should probably just remove it. It's potentially brittle, and this is an easier solution.

dirkgr (Member Author):

Maybe, but in this case it's easy to write a null implementation. Not all classes are like that. But I guess we can always raise NotImplementedError().
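
(For context, a toy version of the registration pattern in the diff above, using a hand-rolled registry instead of allennlp's Registrable: the no-op base class itself is registered under "null", so configs get an explicit default without any special-casing in FromParams.)

from typing import Callable, Dict, Type

_REGISTRY: Dict[str, Type] = {}


def register(name: str) -> Callable[[Type], Type]:
    # Tiny stand-in for Registrable.register: store classes by name.
    def decorator(cls: Type) -> Type:
        _REGISTRY[name] = cls
        return cls
    return decorator


class BatchCallback:
    def __call__(self, *args, **kwargs) -> None:
        pass  # the base class is itself the no-op implementation


# Same trick as in the diff: register the base class as the "null" callback.
register("null")(BatchCallback)
assert _REGISTRY["null"] is BatchCallback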

        for metadata in batch["metadata"]:
            logger.info(f"First word from training data: '{metadata['words'][0]}'")

    def in_worker(self, *args, **kwargs) -> None:

matt-gardner (Contributor):

What's this doing? I don't see it called anywhere.

dirkgr (Member Author):

Sorry, that was a leftover from an earlier iteration. I removed it.
