Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] Pytrees are Trees #3768

Merged
merged 1 commit into from
Mar 21, 2024
Merged

[nnx] Pytrees are Trees #3768

merged 1 commit into from
Mar 21, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 18, 2024

What does this PR do?

  • Traversal procedure no longer caches pytree nodes, this reintroduces referential transparency for pytree types.
  • list, dict, and tuple are no longer graph nodes, they are they are supported as tree nodes.

Other changes

  • Renames MutableNodeImpl to GraphNodeImpl.
  • Renames ImmutableNodeImpl to PytreeNodeImpl.
  • Renames Sequence to List for consistency and removes the __call__ method.
  • Created Sequential that inherits form List and implements __call__.

@cgarciae cgarciae force-pushed the nnx-improve-graph-update branch 2 times, most recently from 7694ab8 to c8cf3f5 Compare March 19, 2024 13:16
@cgarciae cgarciae force-pushed the nnx-pytree-are-trees branch from 32c5836 to 4aa5ed2 Compare March 19, 2024 14:16
@cgarciae cgarciae changed the title [nnx] pytree are trees [nnx] Pytrees are Trees Mar 19, 2024
Base automatically changed from nnx-improve-graph-update to main March 19, 2024 17:05
@cgarciae cgarciae force-pushed the nnx-pytree-are-trees branch from 4aa5ed2 to e854b48 Compare March 19, 2024 17:27
@codecov-commenter
Copy link

codecov-commenter commented Mar 19, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 59.50%. Comparing base (8220154) to head (7317d9b).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3768      +/-   ##
==========================================
+ Coverage   59.49%   59.50%   +0.01%     
==========================================
  Files         101      101              
  Lines       12623    12595      -28     
==========================================
- Hits         7510     7495      -15     
+ Misses       5113     5100      -13     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@cgarciae cgarciae force-pushed the nnx-pytree-are-trees branch from e854b48 to 7317d9b Compare March 19, 2024 20:34
@@ -113,34 +113,25 @@ def init(self, node: Node, items: tuple[tuple[str, Leaf], ...]):


@dataclasses.dataclass(frozen=True)
class ImmutableNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: maybe use PyTree instead of Pytree to match the capitalization in JAX?

@copybara-service copybara-service bot merged commit 12b919c into main Mar 21, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-pytree-are-trees branch March 21, 2024 21:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants