Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow manually specifying the rng key for Dropout #3114

Merged
merged 1 commit into from
May 24, 2023

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented May 24, 2023

What does this PR do?

Fixes #3115.

  • Adds an optional rng argument to Dropout.__call__ so users can manually specify a PRNGKey if more fine grained control is required.
  • Improves typing for PRNGKey type alias (renamed to KeyArray to match JAX naming).

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Global configuration options for Flax.
r"""Global configuration options for Flax.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why I needed to add this but Python was refusing to run this line due to a strange character.

@cgarciae cgarciae requested a review from jheek May 24, 2023 15:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Force no split in make_rng
2 participants