Releases: google/flax
Version 0.10.2
What's Changed
- Add nnx.fori_loop by @IvyZX in #4353
- Linesearch (and lbfgs) support by @jlperla in #4351
- Upgrade Flax NNX Haiku Linen migration doc by @8bitmp3 in #4200
- Fix PRNG handling in nn.jit under nn.scan. by @copybara-service in #4359
- support passing arguments directly to the struct.dataclass decorator by @copybara-service in #4275
- Avoid assert_array_equal for PRNG keys. by @copybara-service in #4363
- [nnx] support pure dicts by @cgarciae in #4352
- [nnx] add data parallel toy example by @cgarciae in #4354
- Add logical axis global context support for NNX by @IvyZX in #4350
- [nnx] fix ToLinen kwargs by @copybara-service in #4270
- [nnx] use HashableMapping instead of FrozenDict by @cgarciae in #4376
- [nnx] fix while_loop/fori_loop bug when sharing references by @cgarciae in #4379
- Add flax.nnx.eval_shape docstring by @8bitmp3 in #4374
- Setup the flaxlib in C++, using Meson and Nanobind. by @copybara-service in #4380
- Add flax.nnx.remat docstring by @8bitmp3 in #4373
- [nnx] add checkify by @cgarciae in #4381
- Lint flax.nnx.while_loop docstring by @8bitmp3 in #4371
- Lint flax.nnx.fori_loop docstring by @8bitmp3 in #4370
- [nnx] add some optimizations to graph.py by @cgarciae in #4377
- update version to 0.10.2 by @cgarciae in #4387
New Contributors
Full Changelog: v0.10.1...v0.10.2
Version 0.10.1
What's Changed
- Add Flax NNX GraphDef docstring by @8bitmp3 in #4302
- Flesh out the Haiku/Flax guide by @IvyZX in #4305
- [nnx] improve mnist tutorial by @cgarciae in #4316
- Update Flax Evolution from Linen to NNX guide by @8bitmp3 in #4289
- [nnx] try casting integers keys in State.replace_by_pure_dict by @cgarciae in #4317
- Fixed nnx examples bad links in the README.md by @vfdev-5 in #4282
- Fix philosophy link by @jorisSchaller in #4313
- [nnx] add gemma notebook by @cgarciae in #4075
- [nnx] improve init_cache docs by @cgarciae in #4291
- remove markdown from section titles by @cgarciae in #4322
- Avoid depending on JAX internals, which are about to change. by @copybara-service in #4326
- Remove outdated compatibility code. by @jakevdp in #4324
- fix ruff complaints by @levskaya in #4331
- Remove GeGLU activation function and golden tests. by @copybara-service in #4303
- Avoid using float32 in normalization for mean/var and scale/bias parameters when force_float32_reductions=False by @copybara-service in #4314
- Avoid assert_array_equal on PRNG keys. by @jakevdp in #4332
- Fix typos in Flax NNX Migrating from Haiku to Flax by @8bitmp3 in #4337
- Add API reference for flax.nnx.nn and improve landing page by @IvyZX in #4338
- [nnx] improve transforms guide by @cgarciae in #4333
- [nnx] cleanup gemma notebook by @cgarciae in #4334
- Remove non-lazy RNG compat mode and flag from flax. by @copybara-service in #4339
- [nnx] fix custom_vjp by @cgarciae in #4306
- Define model surgery in docs by @8bitmp3 in #4349
- [nnx] update State and variables docstrings by @cgarciae in #4346
- Add NNX transforms nnx.while_loop and nnx.switch by @IvyZX in #4343
- update version to v0.10.1 by @cgarciae in #4345
New Contributors
Full Changelog: v0.10.0...v0.10.1
Version 0.10.0
What's Changed
- [nnx] clear nnx basics pip logs by @cgarciae in #4149
- Support linen <-> nnx metadata box converging in nnx.bridge by @IvyZX in #4145
- Add nnx bridge API reference to site by @IvyZX in #4158
- [nnx] use jax-style transforms API in nnx_basics by @cgarciae in #4155
- [nnx] improve nnx.scan in_axes/out_axes by @cgarciae in #4157
- Support direct quantization for FP8 matmul by @wenscarl in #3922
- Upgrade Flax NNX Model Surgery by @8bitmp3 in #4135
- [nnx] add more Variable proxy methods by @cgarciae in #4170
- [nnx] disallow Array leaves by @copybara-service in #4172
- Internal change by @copybara-service in #4176
- [nnx] improve landing page and nnx_basics messaging by @cgarciae in #4168
- Fixes a small bug in flax.linen.share_scope, where the scopes of children of the module being merged that were created before setup(), were not being updated to point to the new scope, and so they would end up staying under the original tree. by @copybara-service in #4150
- Move all NNX content up a level to be equal with Linen, to make python packaging more consistent. by @copybara-service in #4177
- Add a guide for nnx.bridge by @IvyZX in #4171
- [nnx] improve Optimizer metadata propagation by @cgarciae in #4180
- [nnx] enable sharding transformation on integer prefixes by @cgarciae in #4185
- Support linen.LogicallyPartitioned <-> nnx.Variable by @IvyZX in #4161
- Clean up axis hooks in nnx.Variable by @IvyZX in #4189
- Merge nnx.errors to flax.errors by @IvyZX in #4186
- [nnx] optimize jit by @cgarciae in #4191
- Split documentation for Linen and NNX by @cgarciae in #4192
- Partially revert #4192 which sets back a bunch of previous merged pushes. by @copybara-service in #4201
- Align bridge variable tree structures by @IvyZX in #4194
- [NNX site] Fix landing page and banner phrasing and add examples page by @IvyZX in #4202
- shorten banners by @cgarciae in #4206
- Add trimmed Linen to NNX guide by @IvyZX in #4209
- Minor documentation fixes for AxisMetadata. by @copybara-service in #4178
- fix tests for numpy 2.0 compatibility by @copybara-service in #4215
- Forward all arguments when using nnx.transforms.deprecated.scan as a decorator. by @copybara-service in #4208
- [nnx] add transforms guide by @cgarciae in #4197
- [nnx] fix transforms guide by @cgarciae in #4223
- Flax NNX GSPMD guide by @IvyZX in #4220
- Update libraries to use JAX's limited (and ill-advised) trace-state-querying APIs rather than depending on JAX's deeper internals, which are about to change. by @copybara-service in #4225
- [nnx] add Randomness guide by @cgarciae in #4216
- Add pure dict conversion util functions to nnx.State. by @IvyZX in #4230
- [nnx] Simplify traversal by @cgarciae in #4205
- Fix false positive tracer leaks in flax library. by @copybara-service in #4232
- [nnx] add flaxlib by @copybara-service in #4235
- [nnx] improve docs by @cgarciae in #4236
- point nnx banner to flax-linen by @cgarciae in #4237
- update banners by @cgarciae in #4238
- Fix scale dtype and refactor q_dot_dq by @wenscarl in #4229
- update banners by @cgarciae in #4241
- Add redirects for Linen guide links in the NNX site scope. by @IvyZX in #4242
- Internal change by @copybara-service in #4243
- Copybara import of the project: by @copybara-service in #4245
- Update Flax NNX Scale Up SPMD guide by @8bitmp3 in #4239
- Upgrade Flax NNX basics doc by @8bitmp3 in #4173
- Improve landing page, glossary and misc by @IvyZX in #4244
- Nitting and adding links by @8bitmp3 in #4248
- enable doctest on notebooks by @cgarciae in #4250
- Update index.rst by @ariG23498 in #4251
- Add NNX checkpointing guide by @IvyZX in #4249
- Add checkpointing guide to website index. by @copybara-service in #4263
- Update to Flax NNX Transforms doc by @8bitmp3 in #4264
- Add why nnx by @cgarciae in #4240
- [nnx] add cloudpickle support by @cgarciae in #4253
- Fix typo: impost to import by @Vilin97 in #4256
- [nnx] revive TrainState toy example by @cgarciae in #4226
- [nnx] add custom_vjp to docs by @cgarciae in #4266
- remove flax-nnx urls by @cgarciae in #4267
- Add flatten to nnx.graph autosummary in graph.rst by @8bitmp3 in #4255
- [nnx] add FSDP toy example with custom optimizer by @cgarciae in #4183
- Update Flax NNX Landing Page by @8bitmp3 in #4274
- Update to Flax NNX Model Surgery by @8bitmp3 in #4276
- Update Why Flax NNX guide by @8bitmp3 in #4262
- Update to Flax NNX MNIST tutorial by @8bitmp3 in #4277
- [nnx] improve randomness guide by @cgarciae in #4281
- Remove notebook exceptions in docs_nnx doctest by @IvyZX in #4285
- [nnx] add PrefixMapping by @cgarciae in #4278
- [nnx] state filters by @cgarciae in #4288
- Fix devcontainer setup by @jorisSchaller in #4299
- Upgrade Flax NNX Checkpointing guide by @8bitmp3 in #4294
- Update Flax NNX Scale Up guide by @8bitmp3 in #4296
- Porting RNN from Linen to NNX by @zinccat in #4272
- Update Flax NNX Glossary by @8bitmp3 in #4284
- update version to 0.10.0 by @cgarciae in #4292
New Contributors
- @ariG23498 made their first contribution in #4251
- @Vilin97 made their first contribution in #4256
- @jorisSchaller made their first contribution in #4299
- @zinccat made their first contribution in #4272
Full Changelog: v0.9.0...v0.10.0
v0.9.0
What's Changed
- Add NNX surgery guide by @IvyZX in #4005
- Port gemma/transformer to NNX by @copybara-service in #4019
- upgrade python to 3.10 + use pyupgrade by @cgarciae in #4038
- [nnx] add Using Filters guide by @cgarciae in #4028
- v0.8.6 by @cgarciae in #4040
- allow imagenet training profiling to be disabled in config by @copybara-service in #4043
- [nnx] LoRAParam inherits from Param by @cgarciae in #3988
- [linen] allows multiple compact methods by @cgarciae in #3808
- Added support of NANOO fp8. by @wenchenvincent in #3993
- Add functools.wraps() annotation to flax.nn.jit. by @copybara-service in #4051
- Fix typo in nnx_basics doc by @rajasekharporeddy in #4047
- [nnx] fix Variable overloads and add shape/dtype properties by @cgarciae in #4049
- Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4039
- [nnx] stabilize unsafe_pytree by @cgarciae in #4030
- Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4055
- [NVIDIA] Rename fp8 custom dtype to fp32_max_grad by @kaixih in #3984
- [nnx] fix mnist_tutorial colab link by @cgarciae in #4063
- [nnx] fix Accuracy on eager mode by @cgarciae in #4065
- Update orbax_upgrade_guide.rst for async checkpointing usage examples by @kaushaladiti-2802 in #4036
- Re-enable some tests after Python 3.9 is dropped by @IvyZX in #4067
- Rename nnx.compat to nnx.bridge by @IvyZX in #4066
- [nnx] improve mnist tutorial by @cgarciae in #4070
- Modify Flax checkpointing in preparation for cl/650338576. by @copybara-service in #4072
- Remove some outdated backward-compatibility code. by @copybara-service in #4068
- [NVIDIA] Add a user guide for fp8 by @kaixih in #4076
- [nnx] add extract APIs by @cgarciae in #4078
- [example]: remove lm1b useless parallelism rules by @knightXun in #4077
- [nnx] improve filters guide by @cgarciae in #4059
- [nnx] add call by @cgarciae in #4004
- Ignore Orbax warning in deprecated flax.training.checkpoints.py to unbreak head doctest by @IvyZX in #4092
- fix mypy failures due to numpy update by @cgarciae in #4098
- [linen] generalize transform caching by @copybara-service in #4057
- [linen] fold rngs on jit to improve caching by @copybara-service in #4064
- Add shape-based lazy init to LinenToNNX (prev LinenWrapper) by @IvyZX in #4081
- [nnx] add reseed by @cgarciae in #4099
- [nnx] add split/merge_inputs by @cgarciae in #4084
- Perform shape checks for self.param AFTER unboxing by @danielwatson6 in #4079
- fix restore_checkpoint example in docstring by @copybara-service in #4101
- [numpy] Fix users of NumPy APIs that are removed in NumPy 2.0. by @copybara-service in #4104
- set profile_duration_ms = None as in periodic_actions there's default value for both num_profile_steps and profile_duration_ms, and the profile stopping condition is when both num_profile_steps and profile_duration_ms are satisfied, so setting profile_duration_ms=None so that the passed num_profile_steps value gets used by @copybara-service in #4096
- [linen] add share_scope by @cgarciae in #4102
- Allow metadata pass-through in flax.struct.field by @cool-RR in #4056
- avoid mixing einsum_dot_general and einsum argument by specifying them explicitly in the caller. by @copybara-service in #4115
- Add logging to track deprecated codepaths. by @copybara-service in #4121
- [pmap no rank reduce cleanup]: When flipping the by @copybara-service in #4125
- Add NNXToLinen wrapper to nnx.bridge by @IvyZX in #4126
- Switch NNX to use Treescope instead of Penzai. by @copybara-service in #4132
- Add GroupNorm to NNX normalization layers by @treigerm in #4095
- [nnx] fix initializing propagation by @cgarciae in #4134
- add JAX-style NNX Transforms FLIP by @cgarciae in #4108
- Fix _ParentType annotation by @dcharatan in #4120
- add uv.lock file by @copybara-service in #4139
- use uv package manager by @cgarciae in #4136
- More testing and misc fixes on wrappers by @IvyZX in #4137
- Fix link to orbax documentation by @cool-RR in #4123
- [nnx] experimental transforms by @cgarciae in #3963
- [nnx] improve docs by @cgarciae in #4141
- remove repeated license headers by @cgarciae in #4148
- update Flax to version 0.9.0 by @copybara-service in #4147
New Contributors
- @wenchenvincent made their first contribution in #3993
- @rajasekharporeddy made their first contribution in #4047
- @kaushaladiti-2802 made their first contribution in #4036
- @knightXun made their first contribution in #4077
- @danielwatson6 made their first contribution in #4079
- @cool-RR made their first contribution in #4056
- @treigerm made their first contribution in #4095
- @dcharatan made their first contribution in #4120
Full Changelog: v0.8.5...v0.9.0
v0.8.5
What's Changed
- v0.8.5 by @cgarciae in #3941
- [nnx] improve vmap axis size detection by @cgarciae in #3947
- Add direct penzai.treescope support for NNX objects. by @copybara-service in #3948
- [nnx] fix nnx_basics dependencies by @cgarciae in #3942
- Rename all the NNX tests to internal naming & build conventions. by @copybara-service in #3952
- updated rng guide by @chiamp in #3912
- upgraded haiku guide to include NNX by @chiamp in #3923
- parameterized NNX transforms tests by @chiamp in #3906
- Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes. by @copybara-service in #3957
- fix HEAD by @chiamp in #3960
- Minor grammar fixes to NNX documentation. by @mcsmart76 in #3953
- Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3928
- Adding Welford metric. by @copybara-service in #3959
- Modify Welford metric to return mean value. by @copybara-service in #3970
- [nnx] make State generic by @cgarciae in #3964
- updated NNX nn docstrings by @chiamp in #3972
- make flax work with upcoming JAX change to tree_map (being more careful about by @copybara-service in #3976
- updated nnx.module docstrings by @chiamp in #3966
- updated nnx.Conv and nnx.ConvTranspose by @chiamp in #3974
- updated nnx.graph docstrings by @chiamp in #3958
- Adds pmap and Pmap. static_broadcasted_argnums, donate_argnums, and global_arg_shapes are not yet supported. by @copybara-service in #3978
- Fixes for batch norm docs by @jkarwowski in #3982
- fix deprecation warning by @chiamp in #3981
- updated NNX rnglib docstring by @chiamp in #3980
- updated nnx.training by @chiamp in #3975
- updated nnx.variables docstrings by @chiamp in #3986
- [nnx] vectorize vmap split counts by @cgarciae in #3989
- added wrt option to nnx.Optimizer by @chiamp in #3983
- Added nnx.graph.iter_children by @chiamp in #3991
- [nnx] fix vmap by @copybara-service in #3995
- Fix head pytest breakage by @IvyZX in #4006
- Helper function for loading params from a linen module by @copybara-service in #4012
- Port gemma/layers to NNX by @copybara-service in #4013
- [nnx] fix grad by @cgarciae in #4007
- [nnx] add PathContains Filter by @cgarciae in #4011
- Support Python 3.9 by @copybara-service in #4018
- Port gemma/modules to NNX by @copybara-service in #4014
- Internal change to fix current head CI by @copybara-service in #4017
- Unpin the Orbax pip version. by @copybara-service in #4024
- Fix Gemma test to unbreak head by @IvyZX in #4025
- Fix pickling of exceptions by @sanderland in #4002
- Call user-defined variable transforms before determining axis size in nn.vmap. by @copybara-service in #4026
- CI: add test run against oldest supported jax version by @jakevdp in #3996
- Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful. by @copybara-service in #4029
New Contributors
- @mcsmart76 made their first contribution in #3953
- @jkarwowski made their first contribution in #3982
- @sanderland made their first contribution in #4002
Full Changelog: v0.8.4...v0.8.5
v0.8.4
What's Changed
- fixed codecov by @chiamp in #3895
- Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3880
- Share nnx node registry between threads by @NeilGirdhar in #3901
- fixed jnp.clip deprecation by @chiamp in #3905
- Added three tab option to sphinx directive codediff and added testing for first tab by @chiamp in #3847
- Add support for jax.sharding.PartitionSpec.UNCONSTRAINED in logical specification by @copybara-service in #3902
- [nnx] fix mypy and pytype by @cgarciae in #3894
- [nnx] fix iter_nodes by @cgarciae in #3889
- [nnx] Sequential uses regular list by @cgarciae in #3909
- [nnx] add ConvTranspose by @cgarciae in #3908
- [nnx] add Module pytree_experimental static test by @cgarciae in #3864
- Added docstring for Module.scope.path by @chiamp in #3913
- [linen] test jit caching with state updates by @cgarciae in #3900
- v0.8.4 by @cgarciae in #3891
- [linen] enable separate initializers for out layer in MultiHeadDotProductAttention by @cgarciae in #3835
- [nnx] cleanup graph by @cgarciae in #3915
- [nnx] fix bugs by @cgarciae in #3925
- Replace deprecated jax.tree_* functions with jax.tree.* by @copybara-service in #3926
- [nnx] Object refactor by @cgarciae in #3910
- [nnx] add iter_graph by @cgarciae in #3919
- [nnx] add compat by @cgarciae in #3921
- [nnx] transforms refactor by @cgarciae in #3927
- added equivalence test for nnx.ConvTranspose by @chiamp in #3934
- added equivalence test for nnx.Sequential by @chiamp in #3935
- [NNX] Add LoRA and LoRALinear to NNX by @IvyZX in #3929
- [nnx] fix substate mutability by @cgarciae in #3932
- [nnx] improve update context by @cgarciae in #3933
- [nnx] move out of experimental by @cgarciae in #3936
Full Changelog: v0.8.3...v0.8.4
v0.8.3
What's Changed
- Add git fetch upstream to contributing doc. by @carlosgmartin in #3757
- removed getattr/setattr unboxing magic from nnx.Pytree by @chiamp in #3743
- added Einsum layer to NNX by @chiamp in #3741
- Make TrainState's step possibly jax.Array. This makes replicate valid for type checking. by @copybara-service in #3763
- v0.8.3 by @cgarciae in #3758
- [nnx] fix demo notebook by @cgarciae in #3744
- added nnx api reference by @chiamp in #3762
- updated rng docstring for init, apply and make_rng by @chiamp in #3765
- use note box in make_rng docstring by @cgarciae in #3767
- [nnx] improved graph update mechanism by @cgarciae in #3759
- use note box in docstrings by @chiamp in #3769
- Add reset_gate flag to MGUCell. by @carlosgmartin in #3760
- Access thread_resources via jax.interpreters.pxla instead of jax.experimental.maps by @copybara-service in #3775
- Minor doc improvements by @canyon289 in #3588
- added MGU reset_gate test by @chiamp in #3773
- [nnx] Pytrees are Trees by @cgarciae in #3768
- Use short-circuiting access to debug_key_reuse by @copybara-service in #3781
- fix tabulate on norm wrappers by @chiamp in #3772
- Add kw_only struct.dataclass test by @chiamp in #3651
- extended PyTreeNode to take dataclass kwargs by @chiamp in #3785
- [nnx] Arrays are state by @cgarciae in #3791
- [nnx] add GraphNode base class by @cgarciae in #3790
- [nnx] jit accepts many Modules by @cgarciae in #3783
- Exposing the experimental _split_transpose JAX scan parameter in Flax. by @copybara-service in #3795
- Expose nnx.GraphNode by @chiamp in #3796
- [nnx] Rngs and RngStream inherit from GraphNode by @cgarciae in #3793
- [nnx] TrainState uses struct by @cgarciae in #3788
- [nnx] split returns graphdef first by @cgarciae in #3794
- Remove the uninitialized field "embedding" in nn.Embed by @copybara-service in #3801
- Add nnx.training by @chiamp in #3782
- [nnx] non-str State keys by @cgarciae in #3802
- [nnx] allow all jit kwargs in nnx.jit by @cgarciae in #3809
- [nnx] simplify readme by @cgarciae in #3805
- [nnx] Fix nnx basics by @cgarciae in #3812
- [nnx] grad accepts argnums by @cgarciae in #3798
- [nnx] improve toy examples by @cgarciae in #3813
- [nnx] expose Sequential by @cgarciae in #3814
- [nnx] Rng Variable tags by @cgarciae in #3807
- [nnx] remove copy in graph unflatten by @cgarciae in #3804
- fixed optax guide links and docstring typos by @chiamp in #3789
- added dropout broadcast test by @chiamp in #3776
- relaxed grads kwarg for Optimizer.update by @chiamp in #3818
- added tree_map deprecation warning filter by @chiamp in #3828
- updated tree_map by @chiamp in #3823
- added NNX vs JAX transformations guide by @chiamp in #3819
- Updated NNX MNIST tutorial by @chiamp in #3810
- [nnx] add Dropout.rngs by @cgarciae in #3815
- removed autosummary from linen docs by @chiamp in #3792
- Fix cloudpickle sentinel cloning by @cgarciae in #3825
- [nnx] remove pytreelib by @cgarciae in #3816
- [nnx] fix nnx_basics by @cgarciae in #3839
- [linen] fix DenseGeneral init by @cgarciae in #3834
- [nnx] jit constrain object state by @cgarciae in #3817
- Copybara import of the project: by @copybara-service in #3857
- Add example of unbox() and replace_boxed() to the jit guide by @IvyZX in #3843
- RNNCellBase refactor FLIP by @cgarciae in #3099
- [nnx] Some small documentation suggestions. by @gnecula in #3861
- updated nnx dropout by @chiamp in #3841
- Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, by @copybara-service in #3877
- Add option to skip float32 promotion when computing means and variances for normalization. by @copybara-service in #3873
- added nnx api reference link by @chiamp in #3871
- option of forcing the input of softmax to be fp32 for better numerical stability in mixed-precision training. by @copybara-service in #3874
- allow custom dot_general for einsum. by @copybara-service in #3884
- [NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios by @kaixih in #3827
- updated robots.txt by @chiamp in #3886
- fixed autosummary links by @chiamp in #3887
- Fix jax.tree_util.register_dataclass in older JAX versions. by @copybara-service in #3885
- [nnx] v0.1 by @cgarciae in #3876
Full Changelog: v0.8.2...v0.8.3
v0.8.2
What's Changed
- Add +1 to version after 0.8.1 release by @IvyZX in #3684
- fixed rng guide outputs by @chiamp in #3685
- enforce mask kwarg in norm layers by @chiamp in #3663
- added kwargs to self.param and self.variable by @chiamp in #3675
- added nnx normalization tests by @chiamp in #3689
- added NNX init_cache docstring example by @chiamp in #3688
- added nnx attention equivalence test by @chiamp in #3687
- Fix bug that assumed frozen-dict keys were strings. by @copybara-service in #3692
- added nnx rmsnorm by @chiamp in #3691
- updated nnx compute_stats by @chiamp in #3693
- fixed intercept_methods docstring by @chiamp in #3694
- [nnx] Add Sphinx Docs by @cgarciae in #3678
- Fix pointless docstring example of nn.checkpoint / nn.remat. by @levskaya in #3703
- added default params rng to .apply by @chiamp in #3698
- [nnx] add partial_init by @cgarciae in #3674
- make make_rng default to 'params' by @chiamp in #3699
- Add SimpleCell. by @carlosgmartin in #3697
- fix Module.module_paths docstring by @cgarciae in #3709
- Guarantee the latest JAX version on CI by @cgarciae in #3705
- Replace deprecated API jax.tree_map by @copybara-service in #3715
- Use jax.tree_util.tree_map instead of deprecated jax.tree_map. by @copybara-service in #3714
- [nnx] simplify readme by @cgarciae in #3707
- [nnx] add demo.ipynb by @cgarciae in #3680
- Fix Tabulate's compute_flops by @cgarciae in #3721
- [nnx] simplify TraceState by @cgarciae in #3724
- Add broadcast of strides and kernel_dilation to nn.ConvTranspose by @IvyZX in #3731
- [nnx] Fix State.sub by @cgarciae in #3704
- [nnx] always fold_in on fork + new ForkedKeys return type by @cgarciae in #3722
- [nnx] explicit Variables by @cgarciae in #3720
- Improves fingerprint definition for Modules in nn.jit. by @copybara-service in #3736
- Flax: avoid key reuse in tests by @copybara-service in #3740
- added Einsum layer by @chiamp in #3710
- nn.jit: automatic fingerprint definition for dataclass attributes by @cgarciae in #3737
- [NVIDIA] Use custom grad accumulation for FP8 params by @kaixih in #3623
- removed nnx dataclass by @chiamp in #3742
- [nnx] cleanup graph_utils by @cgarciae in #3728
- Fix doctest and unbreak head by @IvyZX in #3753
- [nnx] add pytree support by @cgarciae in #3732
- fixed intercept_methods docstring by @chiamp in #3752
- Add ConvLSTMCell to docs. by @carlosgmartin in #3712
- [nnx] remove flagslib by @cgarciae in #3733
- Fix tests after applying JAX key-reuse checker. See: by @copybara-service in #3748
Full Changelog: v0.8.1...v0.8.2
Version 0.8.1
What's Changed
- bump version number to 0.8.1 by @chiamp in #3649
- Bump pillow from 10.0.1 to 10.2.0 in /examples/vae by @dependabot in #3641
- fixed docstring by @chiamp in #3643
- Add explicit control over frozen/slots setting in flax.struct.dataclass by @copybara-service in #3645
- make Sequential.call compact by @copybara-service in #3647
- add Module.module_paths by @cgarciae in #3654
- added rng_guide by @chiamp in #3497
- Replacing jax.tree_util.tree_map with mapping over leafs. by @copybara-service in #3658
- Copybara import of the project: by @copybara-service in #3659
- added InstanceNorm by @chiamp in #3652
- add Module.module_paths by @copybara-service in #3660
- added norm equivalence tests by @chiamp in #3662
- updated nowrap docstring by @chiamp in #3661
- Add module_paths method to docs by @cgarciae in #3657
- add default make_rng by @chiamp in #3669
- renamed channel_axes to feature_axes in InstanceNorm by @chiamp in #3667
- added flax.typing by @chiamp in #3624
- changed kwargs to actual key-word args by @chiamp in #3562
- updated docs and docstrings by @chiamp in #3670
- re-added linen_intro by @chiamp in #3672
- add compact_name_scope v3 by @cgarciae in #3646
- Release 0.8.1 by @IvyZX in #3682
Full Changelog: v0.8.0...v0.8.1
v0.8.0
What's Changed
- bump version number by @levskaya in #3446
- Add merge / finalize step when using OCDBT driver. Files will be first written to per-process subdirectories, which are later copied by reference to the main directory before the checkpoint is finalized. by @copybara-service in #3426
- fixed quickstart by @chiamp in #3451
- [NVIDIA] Update the algorithm to compute fp8 scales by @kaixih in #3441
- added pre-commit hook that sort imports and formats by @chiamp in #3455
- restructured doc folders by @chiamp in #3434
- Forked a subset of JAX configuration APIs by @superbobry in #3448
- Fix Module.clone in deepclone mode for internal usage. by @levskaya in #3459
- Add user-friendly module copy method. by @levskaya in #3461
- Add simple argument-only lifted nn.grad function. by @levskaya in #3463
- exempt a jax.config deprecation warning by @levskaya in #3465
- Clean up pyproject.toml. by @levskaya in #3468
- Allow for fast accumulation selection for FP8 GEMM by @wenscarl in #3416
- re-added quickstart guide by @chiamp in #3471
- fixed tabulate docstring by @chiamp in #3452
- Add NNX by @cgarciae in #3218
- Bump pillow from 9.5.0 to 10.0.1 in /examples/vae by @dependabot in #3390
- updated attention_test by @chiamp in #3454
- [nnx] Improve docs by @cgarciae in #3478
- added example docstrings by @chiamp in #3453
- fix nn.value_and_grad by implementing directly in core by @levskaya in #3479
- Add dataset loading guide (Issue #2116) by @VictorPrins in #3450
- [nnx] Add support for python container types by @cgarciae in #3486
- remove SelfAttention test and warning filter by @chiamp in #3470
- disabled ruff formatter by @chiamp in #3482
- adding doctest to .rst files by @chiamp in #3481
- changed pip installs to use quotes by @chiamp in #3477
- added enum support for tabulate by @chiamp in #3485
- fix bug in optimizer-api.md by @zhaoyang-0204 in #3462
- removed selfattention from doctest by @chiamp in #3489
- [nnx] Add missing import on why.ipynb by @cgarciae in #3503
- [nnx] switch to nested State representation by @cgarciae in #3502
- Improved Rigor of PReLU Test by @Micky774 in #3498
- added geglu activation and tests by @HMUNACHI in #3512
- [nnx] Add LinearGeneral and MultiHeadAttention by @cgarciae in #3487
- Add NNX/Linen consistency test for Embed layer by @Micky774 in #3513
- Add NNX/Linen API consistency test for Conv layer by @Micky774 in #3511
- Prevent crash in dataclasses with no-init params by @NeilGirdhar in #3514
- [nnx] Variable reference sharing by @cgarciae in #3516
- Added NNX/Linen API consistency test for Linear/Dense layer by @Micky774 in #3509
- Add missing mask argument to LayerNorm, RMSNorm, and GroupNorm. by @carlosgmartin in #3510
- [nnx] Fix graph_utils bug by @cgarciae in #3518
- remove deprecated normalize function by @chiamp in #3531
- Reduced number of parameterizations for Conv NNX/Linen consistency test by @Micky774 in #3526
- Ensure that _hashable_filter does not convert strings to a tuple of letters by @copybara-service in #3533
- added sow attention weights by @chiamp in #3529
- Fix scan out_axes by @cgarciae in #3540
- updated embed docstring by @chiamp in #3539
- add test_scan_negative_axes by @cgarciae in #3542
- add module methods to api docs by @chiamp in #3544
- fixed double backquote code font by @chiamp in #3545
- add nnx conv support for int kernel size by @chiamp in #3537
- added sow attention weights to NNX by @chiamp in #3548
- changed return_weights to sow_weights for attention layer by @chiamp in #3550
- format linen_linear_test.py by @chiamp in #3553
- re-factored features arg by @chiamp in #3554
- updated NNX readme by @chiamp in #3556
- Disable ruff sort imports by @cgarciae in #3560
- Add StateVariablesMapping by @cgarciae in #3523
- add kwargs support for nn.jit by @copybara-service in #3559
- [nnx] Fix readme install instruction by @cgarciae in #3565
- implement Rng.getattr by @cgarciae in #3547
- [nnx] add qkv_features back to MHA by @cgarciae in #3566
- updated readme by @chiamp in #3563
- fixed typo by @chiamp in #3561
- Raise an error for a bad key type by @NeilGirdhar in #3527
- re-factored nnx initializers by @chiamp in #3555
- [nnx] Add complex test with scan + batchnorm + dropout by @cgarciae in #3567
- [nnx] Add interacting with JAX section to README by @cgarciae in #3573
- expose ones and zeros initializers by @chiamp in #3574
- Fix promotion bug in MultiHeadDotProductAttention: by @giovannic in #3571
- fixed error doc formatting by @chiamp in #3587
- [nnx] Improve spmd by @cgarciae in #3580
- [nnx] improve graph_utils._set_key_tuple by @cgarciae in #3592
- [nnx] Fix variable unflatten by @cgarciae in #3578
- [nnx] add open in colab button to why nnx by @cgarciae in #3596
- [nnx] Export missing symbols by @cgarciae in #3583
- [nnx] flaglib add get overloads by @cgarciae in #3582
- Fix type in NNX readme by @shoyer in #3591
- [nnx] add submodule iterator by @cgarciae in #3581
- [nnx] delete flaglib duplicated copyright comment by @cgarciae in #3600
- fixed NNX decode and dynamic slicing by @chiamp in #3576
- [nnx] cleanup CallableProxy by @cgarciae in #3608
- [nnx] improve runtime flags by @cgarciae in #3607
- fixed broken links on quick start guide by @chiamp in #3610
- added multiheadattention alias by @chiamp in #3572
- Rollback of Copybara import of the project: by @copybara-service in #3612
- add missing docs for module functions by @cgarciae in #3619
- fix lm1b data sharding by @cgarciae in #3620
- improve embed by @jianyizh in #3590
- disable ruff linter by @chiamp in #3625
- Add compact_name_scope decorator by @cgarciae in #3621
- Copybara import of the project: by @copybara-service in #3638
- added BatchApply by @chiamp in #3634
- add compact_name_scope v2 by @copybara-service in #3640
- add compact_name_scope v2 by @copybara-service in #3642
- release 0.8.0 by @chiamp in #3644
New Contributors
- @superbobry made their first contribution in #3448
- @VictorPrins made their first contribution in #3450
- @zhaoyang-0204 made their first contribution in #3462
- @Micky774 made their first contribution in #3498
- @HMUNACHI made their first contribution in #3512
- @carlosgmartin made their first contribution in #3510
- @giovannic made their first contribution in https://gith...