From 385ca7384058a81c726621259cffead10292f8fb Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 2 Jul 2024 16:39:16 +0100 Subject: [PATCH] [nnx] fix Variable overloads and add shape/dtype properties --- docs/nnx/nnx_basics.ipynb | 103 ++++++++++++++++++++++--------- docs/nnx/nnx_basics.md | 10 +-- flax/nnx/nnx/variables.py | 127 ++++++++++++++++++++++++++++++-------- 3 files changed, 179 insertions(+), 61 deletions(-) diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb index ae12ce522e..976fa4ac8f 100644 --- a/docs/nnx/nnx_basics.ipynb +++ b/docs/nnx/nnx_basics.ipynb @@ -19,20 +19,61 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "tags": [ "skip-execution" ] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: flax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.8.5)\n", + "Requirement already satisfied: penzai in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.1.3)\n", + "Requirement already satisfied: numpy>=1.22 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (1.26.4)\n", + "Requirement already satisfied: jax>=0.4.27 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.4.31.dev20240621+0428a1509)\n", + "Requirement already satisfied: msgpack in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (1.0.8)\n", + "Requirement already satisfied: optax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.2.2)\n", + "Requirement already satisfied: orbax-checkpoint in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.5.20)\n", + "Requirement already satisfied: tensorstore in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.1.63)\n", + "Requirement already satisfied: rich>=11.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (13.7.1)\n", + "Requirement already satisfied: typing-extensions>=4.2 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (4.12.2)\n", + "Requirement already satisfied: PyYAML>=5.4.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (6.0.1)\n", + "Requirement already satisfied: absl-py>=1.4.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (2.1.0)\n", + "Requirement already satisfied: equinox>=0.11.3 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (0.11.4)\n", + "Requirement already satisfied: ordered_set>=4.1.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (4.1.0)\n", + "Requirement already satisfied: jaxtyping>=0.2.20 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from equinox>=0.11.3->penzai) (0.2.31)\n", + "Requirement already satisfied: jaxlib<=0.4.31,>=0.4.30 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.4.30)\n", + "Requirement already satisfied: ml-dtypes>=0.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.4.0)\n", + "Requirement already satisfied: opt-einsum in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (3.3.0)\n", + "Requirement already satisfied: scipy>=1.9 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (1.14.0)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (2.2.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (2.18.0)\n", + "Requirement already satisfied: chex>=0.1.86 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from optax->flax) (0.1.86)\n", + "Requirement already satisfied: etils[epath,epy] in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.7.0)\n", + "Requirement already satisfied: nest_asyncio in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.6.0)\n", + "Requirement already satisfied: protobuf in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (3.20.3)\n", + "Requirement already satisfied: toolz>=0.9.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from chex>=0.1.86->optax->flax) (0.12.1)\n", + "Requirement already satisfied: typeguard==2.13.3 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jaxtyping>=0.2.20->equinox>=0.11.3->penzai) (2.13.3)\n", + "Requirement already satisfied: mdurl~=0.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax) (0.1.2)\n", + "Requirement already satisfied: fsspec in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (2024.6.0)\n", + "Requirement already satisfied: importlib_resources in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (6.4.0)\n", + "Requirement already satisfied: zipp in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (3.19.2)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], "source": [ "! pip install -U flax penzai" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -62,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -95,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -108,7 +149,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -146,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -166,7 +207,7 @@ " self.count = Count(jnp.array(0))\n", "\n", " def __call__(self):\n", - " self.count.value += 1\n", + " self.count += 1\n", "\n", "counter = Counter()\n", "print(f'{counter.count.value = }')\n", @@ -196,13 +237,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -255,13 +296,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -319,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -383,7 +424,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -396,7 +437,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -454,13 +495,13 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -477,13 +518,15 @@ " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", " self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))\n", " self.b = nnx.Param(jnp.zeros((dout,)))\n", - " self.count = Count(0)\n", + " self.count = Count(jnp.array(0, dtype=jnp.uint32))\n", "\n", " def __call__(self, x: jax.Array):\n", - " self.count.value += 1\n", - " return x @ self.w.value + self.b.value\n", + " self.count += 1\n", + " return x @ self.w + self.b\n", " \n", "model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))\n", + "y = model(jnp.ones((1, 3)))\n", + "\n", "nnx.display(model)" ] }, @@ -501,13 +544,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -519,7 +562,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -550,15 +593,15 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "model.count.value = 0\n", - "model.count.value = Array(1, dtype=int32, weak_type=True)\n" + "model.count.value = Array(1, dtype=uint32)\n", + "model.count.value = Array(2, dtype=uint32)\n" ] } ], @@ -621,13 +664,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -639,7 +682,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -667,7 +710,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md index 1bc9c503ec..5cf875c3ae 100644 --- a/docs/nnx/nnx_basics.md +++ b/docs/nnx/nnx_basics.md @@ -95,7 +95,7 @@ class Counter(nnx.Module): self.count = Count(jnp.array(0)) def __call__(self): - self.count.value += 1 + self.count += 1 counter = Counter() print(f'{counter.count.value = }') @@ -279,13 +279,15 @@ class StatefulLinear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) - self.count = Count(0) + self.count = Count(jnp.array(0, dtype=jnp.uint32)) def __call__(self, x: jax.Array): - self.count.value += 1 - return x @ self.w.value + self.b.value + self.count += 1 + return x @ self.w + self.b model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0)) +y = model(jnp.ones((1, 3))) + nnx.display(model) ``` diff --git a/flax/nnx/nnx/variables.py b/flax/nnx/nnx/variables.py index 1bed4ded51..1f1c8e84ae 100644 --- a/flax/nnx/nnx/variables.py +++ b/flax/nnx/nnx/variables.py @@ -465,7 +465,15 @@ def on_remove_axis( ) -> V: raise NotImplementedError - # operator overloads + # overloads + @property + def shape(self) -> tuple[int, ...]: + return self.value.shape # type: ignore + + @property + def dtype(self) -> Any: + return self.value.dtype # type: ignore + def __jax_array__(self): return self.value @@ -556,44 +564,109 @@ def __rxor__(self, other) -> A: def __ror__(self, other) -> A: return self.value.__ror__(other) # type: ignore - def __iadd__(self, other) -> A: - return self.value.__iadd__(other) # type: ignore + def __iadd__(self: V, other) -> V: + value = self.value + if hasattr(value, '__iadd__'): + value.__iadd__(other) + else: + self.value = value.__add__(other) + return self - def __isub__(self, other) -> A: - return self.value.__isub__(other) # type: ignore + def __isub__(self: V, other) -> V: + value = self.value + if hasattr(value, '__isub__'): + value.__isub__(other) + else: + self.value = value.__sub__(other) + return self - def __imul__(self, other) -> A: - return self.value.__imul__(other) # type: ignore + def __imul__(self: V, other) -> V: + value = self.value + if hasattr(value, '__imul__'): + value.__imul__(other) + else: + self.value = value.__mul__(other) + return self - def __imatmul__(self, other) -> A: - return self.value.__imatmul__(other) # type: ignore + def __imatmul__(self: V, other) -> V: + value = self.value + if hasattr(value, '__imatmul__'): + value.__imatmul__(other) + else: + self.value = value.__matmul__(other) + return self - def __itruediv__(self, other) -> A: - return self.value.__itruediv__(other) # type: ignore + def __itruediv__(self: V, other) -> V: + value = self.value + if hasattr(value, '__itruediv__'): + value.__itruediv__(other) + else: + self.value = value.__truediv__(other) + return self - def __ifloordiv__(self, other) -> A: - return self.value.__ifloordiv__(other) # type: ignore + def __ifloordiv__(self: V, other) -> V: + value = self.value + if hasattr(value, '__ifloordiv__'): + value.__ifloordiv__(other) + else: + self.value = value.__floordiv__(other) + return self - def __imod__(self, other) -> A: - return self.value.__imod__(other) # type: ignore + def __imod__(self: V, other) -> V: + value = self.value + if hasattr(value, '__imod__'): + value.__imod__(other) + else: + self.value = value.__mod__(other) + return self - def __ipow__(self, other) -> A: - return self.value.__ipow__(other) # type: ignore + def __ipow__(self: V, other) -> V: + value = self.value + if hasattr(value, '__ipow__'): + value.__ipow__(other) + else: + self.value = value.__pow__(other) + return self - def __ilshift__(self, other) -> A: - return self.value.__ilshift__(other) # type: ignore + def __ilshift__(self: V, other) -> V: + value = self.value + if hasattr(value, '__ilshift__'): + value.__ilshift__(other) + else: + self.value = value.__lshift__(other) + return self - def __irshift__(self, other) -> A: - return self.value.__irshift__(other) # type: ignore + def __irshift__(self: V, other) -> V: + value = self.value + if hasattr(value, '__irshift__'): + value.__irshift__(other) + else: + self.value = value.__rshift__(other) + return self - def __iand__(self, other) -> A: - return self.value.__iand__(other) # type: ignore + def __iand__(self: V, other) -> V: + value = self.value + if hasattr(value, '__iand__'): + value.__iand__(other) + else: + self.value = value.__and__(other) + return self - def __ixor__(self, other) -> A: - return self.value.__ixor__(other) # type: ignore + def __ixor__(self: V, other) -> V: + value = self.value + if hasattr(value, '__ixor__'): + value.__ixor__(other) + else: + self.value = value.__xor__(other) + return self - def __ior__(self, other) -> A: - return self.value.__ior__(other) # type: ignore + def __ior__(self: V, other) -> V: + value = self.value + if hasattr(value, '__ior__'): + value.__ior__(other) + else: + self.value = value.__or__(other) + return self def __neg__(self) -> A: return self.value.__neg__() # type: ignore