Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotate tree_util #9079

Merged
merged 1 commit into from
Oct 27, 2022
Merged

Annotate tree_util #9079

merged 1 commit into from
Oct 27, 2022

Conversation

NeilGirdhar
Copy link
Contributor

No description provided.

@NeilGirdhar NeilGirdhar force-pushed the annotate_tree branch 7 times, most recently from dd21e1b to 058274e Compare January 6, 2022 04:07
@mattjj mattjj self-assigned this Jan 7, 2022
Copy link
Collaborator

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks as always, Neil.

I have suggestions which fall into two main categories:

  1. whitespace
  2. having more specific output type annotations (lists rather than sequences).

I'm interested if there are any strong opinions against 2, but it seems preferable to be specific about output types (unless we thought they were likely to change).

jax/_src/tree_util.py Outdated Show resolved Hide resolved
jax/_src/tree_util.py Outdated Show resolved Hide resolved
jax/_src/tree_util.py Outdated Show resolved Hide resolved
jax/_src/tree_util.py Outdated Show resolved Hide resolved
jax/_src/tree_util.py Show resolved Hide resolved
@@ -224,7 +224,8 @@ def fun_remat(*args, **kwargs):
### Utilities

def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
args, in_tree = tree_flatten((args, kwargs))
args_seq, in_tree = tree_flatten((args, kwargs))
args = tuple(args_seq)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious, was this change actually needed? It seems surprising. (It'd be cleaner not to add the extra line!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was tripping a typing error on the commit hook. We could get rid of the extra line by keeping

args_seq, in_tree = tree_flatten((args, kwargs))

here, but then changing the two uses of it to *args_seq and len(args_seq) below. Would that be preferable? I was trying to be minimally invasive, but you're right that shorter is probably clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason that this is a type error by the way is because args changes type. As an argument, it's a tuple, but the result of tree_flatten is either a sequence or a list—neither of which is a subclass of tuple. To be honest, it might be better to give this a whole new name anyway since it doesn't correspond to just args anyway. Maybe in_leaves to match in_tree?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, makes sense! in_leaves sgtm.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jan 7, 2022
@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Jan 8, 2022

Nice! Thanks as always, Neil.

I have suggestions which fall into two main categories:

  1. whitespace

Of course, I'm happy to fix all whitespace to match your style 😄

  1. having more specific output type annotations (lists rather than sequences).

I'm interested if there are any strong opinions against 2, but it seems preferable to be specific about output types (unless we thought they were likely to change).

I don't have any strong opinion, but I should have commented on this in the change. I initially wrote the annotations using List, but it turns out that the underlying jaxlib functions are already annotated using Sequence. This caused typing errors in the tree_util code.

I feel like jaxlib and tree_util should match: either jaxlib should change its annotations to List, or else tree_util should use Sequence to match.

Promising sequence means that you could one day choose a different sequence type than list. For example, you might provide some alternative structure that supports easier tree traversal. Maybe that was the motivation of whoever annotated jaxlib: keep options open. I agree with the jaxlib author that promising less is usually better.

Please let me know what you prefer.

@mattjj
Copy link
Collaborator

mattjj commented Jan 8, 2022

I agree with the jaxlib author that promising less is usually better.

I'm not so sure about this in general. Why not promise even less and say it's just an Iterable? We want to promise the right amount, and IMO there's no harm in continuing to say it's a list, as our docs already did. Or at least, there wouldn't be except for the other issue you mentioned.

but it turns out that the underlying jaxlib functions are already annotated using Sequence. This caused typing errors in the tree_util code.

Ah I see. Let's keep it how your PR does it then! I don't mean to add any extra work.

@NeilGirdhar
Copy link
Contributor Author

I'm not so sure about this in general. Why not promise even less and say it's just an Iterable? We want to promise the right amount,

Yes, you're absolutely right of course.

Ah I see. Let's keep it how your PR does it then! I don't mean to add any extra work.

Okay, thanks! I'll make the change above, squash, and let you know when it's done.

@mattjj
Copy link
Collaborator

mattjj commented Jan 8, 2022

You're so generous and easy to work with, Neil! It's always a pleasure, and I always learn something. You're the best.

@NeilGirdhar NeilGirdhar force-pushed the annotate_tree branch 2 times, most recently from 768784e to c1c89a0 Compare January 8, 2022 04:32
@NeilGirdhar
Copy link
Contributor Author

Done, and thank you, you all are a pleasure to work with as well!

@mattjj
Copy link
Collaborator

mattjj commented Jan 8, 2022

Unfortunately, internal typechecks (i.e. typechecks spanning all google code which uses these functions!) ran into some issues with this.

As one example, jraph does addition on the leaves outputs of tree_flatten calls. But that doesn't typecheck against Sequence, since you can't add Sequences together!

Maybe we could put these annotations under if TYPE_CHECKING? Got a better idea?

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Jan 8, 2022

As one example, jraph does addition on the leaves outputs of tree_flatten calls. But that doesn't typecheck against Sequence, since you can't add Sequences together!

That makes sense. That's why I had to change a couple instances of concatenation in the Jax code to use generalized unpacking.

Maybe we could put these annotations under if TYPE_CHECKING? Got a better idea?

Unfortunately, that won't solve the type checking errors since those run under TYPE_CHECKING too. I can see a few options:

  1. Change tree-util annotations to use lists, and change jaxlib to match,
  2. Change tree-util annotations to use lists, and add casts throughout so that jaxlib doesn't complain, or
  3. Change the client code to cast results to list where necessary.

Option 1 means changing Jaxlib too (a couple small changes here). Also, just glancing over the interface, it's a bit weird that flatten returns a sequence, but flatten_up_to returns a list. Maybe someone thought they could do some optimization by returning immutable structures without casting to list? On the other hand, it might be worth making the interface homogenous.

Option 2 adds extra casts that weren't there before, and is awkward having two mismatched interfaces. It's the least work, but my least favorite.

Option 3 depends on how many type errors you ended up with.

Another question goes back to your point above. What do you want in an ideal world? Would you rather promise list and people can easily concatenate them or would you rather promise sequence and you're free to use another container down the road?

There may be other options, but those are the obvious ones to me. Please let me know what you decide and I'll repair this change.

@mattjj
Copy link
Collaborator

mattjj commented Jan 8, 2022

Thanks for that super clear analysis.

I think Option 3 would be pretty annoying.

I like Option 1 but I wonder if @hawkinsp has a different opinion. Peter, WDYT?

@NeilGirdhar
Copy link
Contributor Author

@hawkinsp If you like, I can make a corresponding pull request to tensorflow's jaxlib to make the interface promise lists?

@hawkinsp
Copy link
Collaborator

hawkinsp commented Feb 3, 2022

Sure, that's fine. Just tag me in the PR so I see it.

@NeilGirdhar NeilGirdhar force-pushed the annotate_tree branch 2 times, most recently from f1ca1a7 to 6593591 Compare June 14, 2022 19:43
@NeilGirdhar NeilGirdhar force-pushed the annotate_tree branch 2 times, most recently from 60e92ee to 6aabaa3 Compare September 6, 2022 21:41
@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Sep 6, 2022

This should pass when Tensorflow 2.10's version of jaxlib is used thanks to tensorflow/tensorflow#54330 being compiled into it.

@NeilGirdhar NeilGirdhar force-pushed the annotate_tree branch 2 times, most recently from 21a1c88 to 0a4ce9e Compare September 26, 2022 21:08
@NeilGirdhar NeilGirdhar force-pushed the annotate_tree branch 6 times, most recently from 6939460 to c1d6f7d Compare October 26, 2022 17:26
@NeilGirdhar
Copy link
Contributor Author

@mattjj Would it be possible to get this merged now that the tensorflow type annotations are in jaxlib?

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

@jakevdp
Copy link
Collaborator

jakevdp commented Oct 26, 2022

(taking over review/merge, since this is related to #12049)

@copybara-service copybara-service bot merged commit 9abacbd into jax-ml:main Oct 27, 2022
@NeilGirdhar NeilGirdhar deleted the annotate_tree branch October 27, 2022 15:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants