-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Annotate tree_util #9079
Annotate tree_util #9079
Conversation
dd21e1b
to
058274e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Thanks as always, Neil.
I have suggestions which fall into two main categories:
- whitespace
- having more specific output type annotations (lists rather than sequences).
I'm interested if there are any strong opinions against 2, but it seems preferable to be specific about output types (unless we thought they were likely to change).
jax/_src/ad_checkpoint.py
Outdated
@@ -224,7 +224,8 @@ def fun_remat(*args, **kwargs): | |||
### Utilities | |||
|
|||
def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]: | |||
args, in_tree = tree_flatten((args, kwargs)) | |||
args_seq, in_tree = tree_flatten((args, kwargs)) | |||
args = tuple(args_seq) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm curious, was this change actually needed? It seems surprising. (It'd be cleaner not to add the extra line!)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was tripping a typing error on the commit hook. We could get rid of the extra line by keeping
args_seq, in_tree = tree_flatten((args, kwargs))
here, but then changing the two uses of it to *args_seq
and len(args_seq)
below. Would that be preferable? I was trying to be minimally invasive, but you're right that shorter is probably clearer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason that this is a type error by the way is because args
changes type. As an argument, it's a tuple, but the result of tree_flatten
is either a sequence or a list—neither of which is a subclass of tuple. To be honest, it might be better to give this a whole new name anyway since it doesn't correspond to just args
anyway. Maybe in_leaves
to match in_tree
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, makes sense! in_leaves
sgtm.
Of course, I'm happy to fix all whitespace to match your style 😄
I don't have any strong opinion, but I should have commented on this in the change. I initially wrote the annotations using I feel like Promising sequence means that you could one day choose a different sequence type than list. For example, you might provide some alternative structure that supports easier tree traversal. Maybe that was the motivation of whoever annotated Please let me know what you prefer. |
I'm not so sure about this in general. Why not promise even less and say it's just an Iterable? We want to promise the right amount, and IMO there's no harm in continuing to say it's a list, as our docs already did. Or at least, there wouldn't be except for the other issue you mentioned.
Ah I see. Let's keep it how your PR does it then! I don't mean to add any extra work. |
Yes, you're absolutely right of course.
Okay, thanks! I'll make the change above, squash, and let you know when it's done. |
You're so generous and easy to work with, Neil! It's always a pleasure, and I always learn something. You're the best. |
768784e
to
c1c89a0
Compare
Done, and thank you, you all are a pleasure to work with as well! |
Unfortunately, internal typechecks (i.e. typechecks spanning all google code which uses these functions!) ran into some issues with this. As one example, jraph does addition on the leaves outputs of Maybe we could put these annotations under |
That makes sense. That's why I had to change a couple instances of concatenation in the Jax code to use generalized unpacking.
Unfortunately, that won't solve the type checking errors since those run under TYPE_CHECKING too. I can see a few options:
Option 1 means changing Jaxlib too (a couple small changes here). Also, just glancing over the interface, it's a bit weird that Option 2 adds extra casts that weren't there before, and is awkward having two mismatched interfaces. It's the least work, but my least favorite. Option 3 depends on how many type errors you ended up with. Another question goes back to your point above. What do you want in an ideal world? Would you rather promise list and people can easily concatenate them or would you rather promise sequence and you're free to use another container down the road? There may be other options, but those are the obvious ones to me. Please let me know what you decide and I'll repair this change. |
Thanks for that super clear analysis. I think Option 3 would be pretty annoying. I like Option 1 but I wonder if @hawkinsp has a different opinion. Peter, WDYT? |
@hawkinsp If you like, I can make a corresponding pull request to tensorflow's jaxlib to make the interface promise lists? |
c1c89a0
to
95f6941
Compare
Sure, that's fine. Just tag me in the PR so I see it. |
41945af
to
3f627b4
Compare
3f627b4
to
b6ee804
Compare
b6ee804
to
a0cd677
Compare
a0cd677
to
ba547b2
Compare
f1ca1a7
to
6593591
Compare
6593591
to
d86ace7
Compare
60e92ee
to
6aabaa3
Compare
This should pass when Tensorflow 2.10's version of jaxlib is used thanks to tensorflow/tensorflow#54330 being compiled into it. |
6aabaa3
to
2e8d5ee
Compare
21a1c88
to
0a4ce9e
Compare
0a4ce9e
to
78500f3
Compare
6939460
to
c1d6f7d
Compare
c1d6f7d
to
b742b04
Compare
@mattjj Would it be possible to get this merged now that the tensorflow type annotations are in jaxlib? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks!
(taking over review/merge, since this is related to #12049) |
No description provided.