This repository has been archived by the owner on Feb 26, 2023. It is now read-only.
Releases: cgarciae/treex
0.6.1
0.6.0
Shape Inference + @compact support 🎉
Changes
- Adds the tx.next_key() function and the tx.rng_key() context manager.
- Module.init now has the following behavior:
  - Accepts an optional inputs argument and runs the forward method if it is given.
  - Sets the given key in the context so tx.next_key() can be used.
  - Accepts a call_method: str argument that defines the method to call; "__call__" is used by default.
- Modules will now be initialized if they are constructed within @tx.compact functions when called by init.
- Adds the @tx.compact_module decorator, which can turn any function into a Module whose compact __call__ is the decorated function.
- New Crossentropy loss that generalizes BinaryCrossentropy, CategoricalCrossentropy and SparseCategoricalCrossentropy.
- New Flatten layers.
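The following is a minimal pure-Python sketch of the idea behind calling init with inputs together with tx.next_key(): a key is placed in a context, and a layer infers its weight shape from the first input it sees. It is not Treex's implementation. The names rng_key, next_key, LazyLinear and its init method are hypothetical stand-ins, plain integers play the role of JAX PRNG keys, and Python's random module replaces jax.random.

```python
import contextlib
import contextvars
import random

# Hypothetical stand-ins for tx.rng_key() / tx.next_key(); ints play the role of PRNG keys.
_active_rng = contextvars.ContextVar("_active_rng", default=None)


@contextlib.contextmanager
def rng_key(key: int):
    """Make a key stream available to next_key() inside the with-block."""
    token = _active_rng.set(random.Random(key))
    try:
        yield
    finally:
        _active_rng.reset(token)


def next_key() -> int:
    """Draw a fresh key from the active stream; fails outside rng_key()."""
    rng = _active_rng.get()
    if rng is None:
        raise RuntimeError("next_key() called outside an rng_key() context")
    return rng.getrandbits(32)


class LazyLinear:
    """Dense layer whose input size is inferred from the first input it sees."""

    def __init__(self, features_out: int):
        self.features_out = features_out
        self.w = None  # shape unknown until an input arrives

    def __call__(self, x):
        if self.w is None:
            # Shape inference: the number of input features is len(x).
            init_rng = random.Random(next_key())
            self.w = [
                [init_rng.uniform(-1.0, 1.0) for _ in range(self.features_out)]
                for _ in range(len(x))
            ]
        # y[j] = sum_i x[i] * w[i][j]
        return [
            sum(xi * row[j] for xi, row in zip(x, self.w))
            for j in range(self.features_out)
        ]

    def init(self, key: int, inputs=None, call_method: str = "__call__"):
        # Set the key in the context, then run the chosen method if inputs are given.
        with rng_key(key):
            if inputs is not None:
                getattr(self, call_method)(inputs)
        return self


layer = LazyLinear(3).init(key=42, inputs=[1.0, 2.0])
print(len(layer.w), len(layer.w[0]))  # 2 3
```

Keeping the key in a context variable lets nested submodules request keys without threading a key argument through every constructor; the trade-off is that calling next_key() outside the context fails at runtime rather than at definition time.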
0.5.0
Major Changes
- Treex now depends on Treeo to generate its Pytree.
- update is now called merge, consistent with Treeo; this also avoids name clashes with Optimizer.
- Kinds are no longer annotations; Treex now uses Treeo's kind system instead. So this annotation
      w: tx.Parameter[jnp.ndarray]
  becomes
      w: jnp.ndarray = tx.Parameter.node()
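To illustrate what moving the kind from the annotation to the field default buys, here is a hypothetical sketch that uses only the standard dataclasses module. It assumes, as a simplification, that a kind is recorded in the field's metadata; Kind, Parameter, State, Linear and fields_of_kind are illustrative names, not Treex or Treeo APIs. Because the annotation stays the value's real type (list here, jnp.ndarray in the example above), static type checkers see the actual type of the field.

```python
import dataclasses


class Kind:
    """Hypothetical base kind; node() returns a dataclass field tagged with the kind."""

    @classmethod
    def node(cls, default=dataclasses.MISSING):
        # The kind lives in field metadata, so the annotation can stay the value's real type.
        return dataclasses.field(default=default, metadata={"kind": cls})


class Parameter(Kind):
    pass


class State(Kind):
    pass


@dataclasses.dataclass
class Linear:
    w: list = Parameter.node()  # old style would have been: w: Parameter[list]
    count: int = State.node(default=0)


def fields_of_kind(obj, kind):
    """Collect the fields of obj whose recorded kind is kind or a subclass of it."""
    selected = {}
    for f in dataclasses.fields(obj):
        field_kind = f.metadata.get("kind")
        if field_kind is not None and issubclass(field_kind, kind):
            selected[f.name] = getattr(obj, f.name)
    return selected


layer = Linear(w=[0.5, -0.5])
print(fields_of_kind(layer, Parameter))  # {'w': [0.5, -0.5]}
print(fields_of_kind(layer, State))      # {'count': 0}
```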
0.4.0
Changes
- Optimizer now flattens its params to be agnostic to the static components of the pytree.
- Generic types containing TreeParts are no longer valid type annotations, because types like Tuple[int, tx.State[int]] make it appear as if the first element of the tuple were static and the second dynamic, when in fact Treex would treat the whole field as dynamic. Now your only option is tx.State[Tuple[int, int]].
- Adds the tx.Hashable class to wrap non-hashable types like numpy or jax arrays when you want to use them in static fields of a TreeObject.
- Adds FlaxModule, which can wrap any Flax Module into a Treex Module.
- tabulate now accepts a sample_input and will show the input and output columns.
- Refactors a lot of the functional API.
- Introduces .freeze(), .unfreeze() and .frozen, similar to train/eval/training.
- Updates BatchNorm and Dropout to leverage .frozen.
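In JAX, static (auxiliary) pytree data is hashed and compared as part of the tree structure, for example when jax.jit looks up a cached compilation, so values kept in static fields must be hashable. A minimal sketch of a wrapper in the spirit of tx.Hashable follows; HashableWrapper is a hypothetical name, and hashing by object identity is an assumption about one reasonable design, not a description of Treex's implementation. A Python list stands in for an unhashable array.

```python
class HashableWrapper:
    """Hypothetical analogue of tx.Hashable: hash and equality by identity."""

    def __init__(self, value):
        self.value = value

    def __hash__(self):
        # id() is always hashable, even when the wrapped value is not.
        return hash(id(self.value))

    def __eq__(self, other):
        # Two wrappers are equal only if they wrap the very same object.
        return isinstance(other, HashableWrapper) and self.value is other.value

    def __repr__(self):
        return f"HashableWrapper({self.value!r})"


table = [1.0, 2.0, 3.0]  # unhashable, like a numpy or JAX array

try:
    hash(table)
except TypeError:
    print("raw value is not hashable")

cache = {HashableWrapper(table): "compiled"}
print(cache[HashableWrapper(table)])  # compiled
```

Identity semantics keep hashing cheap even for large arrays, at the cost that two distinct arrays with identical contents compare unequal, which under jax.jit would mean a cache miss and a recompilation.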
0.3.0
Changes
- TreeParts are now generic and statically behave like Union.
- filter now also accepts predicates.
- Optimizer.update was renamed to apply_updates.
- Expanded TreePart hierarchy.
- Added the RngSeq Module (generates PRNGKeys on demand).
- TreeObject now has a metaclass that checks that super().__init__() is always called.
- Fields with TreeObject values are now automatically annotated if an annotation is not provided by the user.
- filter now also accepts predicates of the type FieldInfo -> bool.
- Added the Static annotation for when you want a field to be explicitly marked as a static part of the Pytree. This is useful if the field will hold a TreeObject but you don't want it to be a child of the Pytree, e.g. ignored_linear: tx.Static[tx.Linear].
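The metaclass check described above can be sketched in plain Python: the metaclass's __call__ runs the normal construction and then verifies that the base class's __init__ left a marker behind. InitCheckMeta, TreeObjectSketch and the _base_init_called attribute are hypothetical names chosen for this sketch, not Treex internals.

```python
class InitCheckMeta(type):
    """Hypothetical sketch: verify after construction that the base __init__ ran."""

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)  # runs __new__ and the full __init__ chain
        if not getattr(obj, "_base_init_called", False):
            raise RuntimeError(f"{cls.__name__}.__init__ must call super().__init__()")
        return obj


class TreeObjectSketch(metaclass=InitCheckMeta):
    def __init__(self):
        self._base_init_called = True  # marker checked by the metaclass


class Good(TreeObjectSketch):
    def __init__(self, n):
        super().__init__()
        self.n = n


class Bad(TreeObjectSketch):
    def __init__(self, n):
        self.n = n  # forgot super().__init__()


print(Good(1).n)  # 1
try:
    Bad(1)
except RuntimeError as err:
    print(err)  # Bad.__init__ must call super().__init__()
```

Doing the check in the metaclass's __call__ means it runs once per construction, after the most-derived __init__ has finished, so it catches a missing super().__init__() regardless of how deep the subclass hierarchy is.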