diff --git a/docs/guides/flax_fundamentals/rng_guide.ipynb b/docs/guides/flax_fundamentals/rng_guide.ipynb index 0cecc3cb45..ae3cd75588 100644 --- a/docs/guides/flax_fundamentals/rng_guide.ipynb +++ b/docs/guides/flax_fundamentals/rng_guide.ipynb @@ -36,29 +36,20 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 1, "metadata": { - "outputId": "bb13d0ba-f908-4746-e4d3-f5916e0670ff", "tags": [ "skip-execution" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/bin/sh: line 1: pip: command not found\n" - ] - } - ], + "outputs": [], "source": [ "!pip install -q flax" ] }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -83,9 +74,9 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 4, "metadata": { - "outputId": "ab5e0d3b-2d51-44ee-b823-d152ad1a10b2" + "outputId": "ec904f6b-0e87-4efe-87c4-fea0f8e8ec23" }, "outputs": [ { @@ -101,7 +92,7 @@ " CpuDevice(id=7)]" ] }, - "execution_count": 45, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -119,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -146,9 +137,9 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 6, "metadata": { - "outputId": "8304fbc2-ab0a-42f0-9d4c-c88a74012e83" + "outputId": "2a16435b-e92a-480a-f9fb-e6effc42c4c2" }, "outputs": [ { @@ -156,11 +147,11 @@ "output_type": "stream", "text": [ "Array((), dtype=key) overlaying:\n", - "[1543086838 3704909070]\n", + "[1428664606 3351135085]\n", "Array((), dtype=key) overlaying:\n", - "[2702764981 3978623664]\n", + "[3456700291 3873160899]\n", "Array((), dtype=key) overlaying:\n", - "[1915779057 2258748098]\n" + "[2411773124 4124888837]\n" ] } ], @@ -185,9 +176,9 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 7, "metadata": { - "outputId": "1cfd9632-11cf-43f4-b922-63038705a195" + "outputId": "985b0f62-dfde-4f0f-fad4-a31927fc9f59" }, "outputs": [ { @@ -195,11 +186,11 @@ "output_type": "stream", "text": [ "Array((), dtype=key) overlaying:\n", - "[1830439201 4095528436]\n", + "[3077990774 2166202870]\n", "Array((), dtype=key) overlaying:\n", - "[3737706588 1614077470]\n", + "[3825832496 2886313970]\n", "Array((), dtype=key) overlaying:\n", - "[2940838374 2782395343]\n" + "[ 791337683 1373966058]\n" ] } ], @@ -216,9 +207,9 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 8, "metadata": { - "outputId": "60dde8ff-08a9-4743-c7ec-c997aa0537b6" + "outputId": "7e8ce538-e380-4db9-db23-bc4a8da577da" }, "outputs": [ { @@ -226,17 +217,17 @@ "output_type": "stream", "text": [ "rng_stream1: Array((), dtype=key) overlaying:\n", - "[1543086838 3704909070]\n", + "[1428664606 3351135085]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", - "[1830439201 4095528436]\n", + "[3077990774 2166202870]\n", "rng_stream1: Array((), dtype=key) overlaying:\n", - "[2702764981 3978623664]\n", + "[3456700291 3873160899]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", - "[3737706588 1614077470]\n", + "[3825832496 2886313970]\n", "rng_stream1: Array((), dtype=key) overlaying:\n", - "[1915779057 2258748098]\n", + "[2411773124 4124888837]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", - "[2940838374 2782395343]\n" + "[ 791337683 1373966058]\n" ] } ], @@ -272,9 +263,9 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 9, "metadata": { - "outputId": "12e8c3c5-a32c-46ab-ba71-98cc90aaed70" + "outputId": "b70be039-589a-48f7-dc54-65e78c449c65" }, "outputs": [ { @@ -282,17 +273,17 @@ "output_type": "stream", "text": [ "rng_stream1: Array((), dtype=key) overlaying:\n", - "[1543086838 3704909070]\n", + "[1428664606 3351135085]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", - "[1543086838 3704909070]\n", + "[1428664606 3351135085]\n", "rng_stream1: Array((), dtype=key) overlaying:\n", - "[2702764981 3978623664]\n", + "[3456700291 3873160899]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", - "[2702764981 3978623664]\n", + "[3456700291 3873160899]\n", "rng_stream1: Array((), dtype=key) overlaying:\n", - "[1915779057 2258748098]\n", + "[2411773124 4124888837]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", - "[1915779057 2258748098]\n" + "[2411773124 4124888837]\n" ] } ], @@ -331,7 +322,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -358,9 +349,9 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 11, "metadata": { - "outputId": "6bf352c1-086a-4c7a-8ab9-92abde614270" + "outputId": "d26b7355-9e8b-4954-b2f4-cf7520d5c5a3" }, "outputs": [ { @@ -385,9 +376,9 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 12, "metadata": { - "outputId": "3809fddc-7319-4d55-ccc5-7b7c78160ea9" + "outputId": "dec627a6-4c5a-4e3e-ce11-ce4f72775261" }, "outputs": [ { @@ -395,11 +386,11 @@ "output_type": "stream", "text": [ "Array((), dtype=key) overlaying:\n", - "[1543086838 3704909070]\n", + "[1428664606 3351135085]\n", "Array((), dtype=key) overlaying:\n", - "[2702764981 3978623664]\n", + "[3456700291 3873160899]\n", "Array((), dtype=key) overlaying:\n", - "[1915779057 2258748098]\n" + "[2411773124 4124888837]\n" ] } ], @@ -425,9 +416,9 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 13, "metadata": { - "outputId": "cc88c770-80ff-4093-e37f-6dc10c24006f" + "outputId": "5b7a9ae9-ca49-4ac0-d007-5caeee739ff0" }, "outputs": [ { @@ -435,17 +426,17 @@ "output_type": "stream", "text": [ "RNGModule, count 1: Array((), dtype=key) overlaying:\n", - "[1543086838 3704909070]\n", + "[1428664606 3351135085]\n", "RNGModule, count 2: Array((), dtype=key) overlaying:\n", - "[2702764981 3978623664]\n", + "[3456700291 3873160899]\n", "RNGSubModule_0, count 1: Array((), dtype=key) overlaying:\n", - "[3619592043 626287670]\n", + "[3858825717 2323087578]\n", "RNGSubModule_0, count 2: Array((), dtype=key) overlaying:\n", - "[965377860 480622172]\n", + "[ 601859108 3782857444]\n", "RNGSubSubModule_0, count 1: Array((), dtype=key) overlaying:\n", - "[1015683150 3648653849]\n", + "[ 234240654 1028548813]\n", "RNGSubSubModule_0, count 2: Array((), dtype=key) overlaying:\n", - "[3694284925 2979568433]\n" + "[3650462303 2124609379]\n" ] } ], @@ -497,9 +488,9 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 14, "metadata": { - "outputId": "d38896e6-8061-4f54-fa30-680b3f524071" + "outputId": "c0de4d37-0f00-4e58-bdfd-e8a6454ed681" }, "outputs": [ { @@ -543,9 +534,9 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 15, "metadata": { - "outputId": "556a79c6-0b38-4736-f5ad-79fd1976b191" + "outputId": "0b77a038-7000-407b-c5b8-a28dea7951d1" }, "outputs": [ { @@ -553,17 +544,17 @@ "output_type": "stream", "text": [ "RNGModule, count 1: Array((), dtype=key) overlaying:\n", - "[1543086838 3704909070]\n", + "[1428664606 3351135085]\n", "RNGModule, count 2: Array((), dtype=key) overlaying:\n", - "[2702764981 3978623664]\n", + "[3456700291 3873160899]\n", "RNGSubModule_0, count 1: Array((), dtype=key) overlaying:\n", - "[3619592043 626287670]\n", + "[3858825717 2323087578]\n", "RNGSubModule_0, count 2: Array((), dtype=key) overlaying:\n", - "[965377860 480622172]\n", + "[ 601859108 3782857444]\n", "RNGSubModule_1, count 1: Array((), dtype=key) overlaying:\n", - "[2339259155 3791202376]\n", + "[ 426957352 2006350344]\n", "RNGSubModule_1, count 2: Array((), dtype=key) overlaying:\n", - "[3825254750 938582370]\n" + "[4006253729 4205356731]\n" ] } ], @@ -588,9 +579,9 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 16, "metadata": { - "outputId": "433f5e48-d379-490f-e499-ef4c24032776" + "outputId": "d189d25e-425d-4fd7-fe18-2dfd63f28b87" }, "outputs": [ { @@ -647,26 +638,26 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 17, "metadata": { - "outputId": "fc7df4d6-5d6e-4c3c-98f2-d5b0e9a8a45e" + "outputId": "a7816385-0e08-48e2-dc51-055d7bcd0bab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[[ 0.6227115 -0.0870162 ]\n", - " [-0.35369048 0.68549377]]\n", - "[[-1.7602162 -0.807241 ]\n", - " [-0.09234211 0.79185605]]\n", - "[[-1.5997566 -0.09624068]\n", - " [ 0.21725878 0.16356604]]\n", - "[[ 0.44956833 -1.1854612 ]\n", - " [ 0.17371362 -0.768862 ]]\n", - "[[ 1.5019631 -1.3983697 ]\n", - " [-0.48637655 -1.4504721 ]]\n", - "[ 2.114508 -0.5090715]\n" + "[[-1.6185919 0.700908 ]\n", + " [-1.3146383 -0.79342234]]\n", + "[[ 0.0761425 -1.6157459]\n", + " [-1.6857724 0.7126891]]\n", + "[[ 0.60175574 0.2553228 ]\n", + " [ 0.27367848 -2.1975214 ]]\n", + "[[1.6249592 0.30813068]\n", + " [1.6613585 1.0404155 ]]\n", + "[[ 0.0030665 0.29551846]\n", + " [ 0.16670242 -0.78252524]]\n", + "[1.582462 0.15216611]\n" ] } ], @@ -720,9 +711,9 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 18, "metadata": { - "outputId": "40745795-8bbc-4c31-811a-f3658e5459d7" + "outputId": "ccec9d64-9a27-47f7-adaf-b36a5ea655db" }, "outputs": [ { @@ -775,7 +766,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -822,25 +813,25 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 20, "metadata": { - "outputId": "9e689bfd-ba73-4b78-c3c0-89693caada2d" + "outputId": "e9da8228-acba-403d-bcb5-33a39d4d530d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0.86758584\n", - "0.86374676\n", - "0.861765\n", - "0.85973275\n", - "0.8573686\n", - "0.85558856\n", - "0.8524709\n", - "0.84871066\n", - "0.8456431\n", - "0.84329045\n" + "2.518454\n", + "2.4859657\n", + "2.4171872\n", + "2.412684\n", + "2.3435805\n", + "2.2773488\n", + "2.2592616\n", + "2.2009292\n", + "2.1839895\n", + "2.1707344\n" ] } ], @@ -883,18 +874,18 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 21, "metadata": { - "outputId": "80e45c76-bb4d-48ab-8fac-8b14126428d1" + "outputId": "887142ff-c9ca-4aae-d9fa-cc9993d809c5" }, "outputs": [ { "data": { "text/plain": [ - "Array(False, dtype=bool)" + "Array(True, dtype=bool)" ] }, - "execution_count": 62, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -930,9 +921,9 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 22, "metadata": { - "outputId": "07fc8349-41a3-4c74-d25c-0ffc4aa1be0a" + "outputId": "001cbd49-129b-4474-c6a1-3255a4ee3dfe" }, "outputs": [ { @@ -962,9 +953,9 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 23, "metadata": { - "outputId": "4ba0c3cb-e903-40af-ab3f-687d83c257c9" + "outputId": "35a2b204-bdfd-4f83-8e98-ba723963cb0c" }, "outputs": [ { @@ -973,7 +964,7 @@ "Array(False, dtype=bool)" ] }, - "execution_count": 64, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -1016,9 +1007,9 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 24, "metadata": { - "outputId": "9deab9d8-3e15-4be0-8d91-279f6984ee99" + "outputId": "6c280522-4b43-4b82-f40a-b73986659b2c" }, "outputs": [ { @@ -1046,9 +1037,9 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 25, "metadata": { - "outputId": "9b47eef2-612e-4d9f-ffb2-1e868cb52d86" + "outputId": "d1bfbcad-e28a-4fae-8136-98bd1efb9332" }, "outputs": [ { @@ -1064,7 +1055,7 @@ " [ 0.17373264]], dtype=float32)" ] }, - "execution_count": 66, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -1111,9 +1102,9 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 26, "metadata": { - "outputId": "f6596bdc-89b4-46bf-ba99-84b1009d8156" + "outputId": "b672b85f-7a2d-44b5-afc1-bbf9426655ed" }, "outputs": [ { @@ -1129,7 +1120,7 @@ " Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.17373264]])]" ] }, - "execution_count": 67, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -1147,33 +1138,53 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 27, "metadata": { - "outputId": "7df87ccd-0979-4b85-fbac-1d31aac53276" + "outputId": "1c0a16ce-fa3f-4b95-d794-58464bbaa9ae" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┐\n", - "│ CPU 0 │\n", - "├───────┤\n", - "│ CPU 1 │\n", - "├───────┤\n", - "│ CPU 2 │\n", - "├───────┤\n", - "│ CPU 3 │\n", - "├───────┤\n", - "│ CPU 4 │\n", - "├───────┤\n", - "│ CPU 5 │\n", - "├───────┤\n", - "│ CPU 6 │\n", - "├───────┤\n", - "│ CPU 7 │\n", - "└───────┘\n" - ] + "data": { + "text/html": [ + "
  CPU 0  \n",
+       "         \n",
+       "  CPU 1  \n",
+       "         \n",
+       "  CPU 2  \n",
+       "         \n",
+       "  CPU 3  \n",
+       "         \n",
+       "  CPU 4  \n",
+       "         \n",
+       "  CPU 5  \n",
+       "         \n",
+       "  CPU 6  \n",
+       "         \n",
+       "  CPU 7  \n",
+       "         \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mCPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mCPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mCPU 4\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mCPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mCPU 6\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mCPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1189,29 +1200,32 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 28, "metadata": { - "outputId": "cb7eb7d0-3fa7-4f72-a129-18901dc61bb1" + "outputId": "fe9ec875-3e7f-4861-babc-f07064737276" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[-1.2839764]\n", - " [-1.2839764]\n", - " [-1.2839764]\n", - " [-1.2839764]\n", - " [-1.2839764]\n", - " [-1.2839764]\n", - " [-1.2839764]\n", - " [-1.2839764]]\n" - ] + "data": { + "text/plain": [ + "Array([[-1.2839764],\n", + " [-1.2839764],\n", + " [-1.2839764],\n", + " [-1.2839764],\n", + " [-1.2839764],\n", + " [-1.2839764],\n", + " [-1.2839764],\n", + " [-1.2839764]], dtype=float32)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "out = jit_forward(variables, x, False, apply_rng)\n", - "print(out)" + "out" ] }, { @@ -1223,9 +1237,9 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 29, "metadata": { - "outputId": "b88834c1-ebb2-422a-a665-1732015a3974" + "outputId": "0a9e5f2c-d4bf-4051-bf71-f32a9c32dc06" }, "outputs": [ { @@ -1241,7 +1255,7 @@ " [ 0.17373264]], dtype=float32)" ] }, - "execution_count": 70, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1253,9 +1267,9 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 30, "metadata": { - "outputId": "5b06f58b-7f1f-4cd6-c29f-3da5c0ea189f" + "outputId": "772a5063-1bd5-46b4-f6f6-cae9b4b81a26" }, "outputs": [ { @@ -1271,7 +1285,7 @@ " [-1.2839764]], dtype=float32)" ] }, - "execution_count": 71, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1301,9 +1315,9 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 31, "metadata": { - "outputId": "29b2911f-28a7-4a25-fcf7-4c3544dafada" + "outputId": "aa00b9a3-24ba-4048-ed8c-afbb9070f039" }, "outputs": [ { @@ -1319,7 +1333,7 @@ " [ 0.9023453 ]], dtype=float32)" ] }, - "execution_count": 72, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -1360,9 +1374,9 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 32, "metadata": { - "outputId": "2d74dff9-3a7d-47df-afbc-3bda4b8edf51" + "outputId": "e304289b-ef1c-4e4a-d4c1-4c41613bfa62" }, "outputs": [ { @@ -1378,7 +1392,7 @@ " Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.9023453]])]" ] }, - "execution_count": 73, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -1389,33 +1403,53 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 33, "metadata": { - "outputId": "28bfe59b-ded1-455a-c894-f1c457ec28bf" + "outputId": "52fdb6d2-4c4f-44b3-feee-4bc5363c8f2f" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┐\n", - "│ CPU 0 │\n", - "├───────┤\n", - "│ CPU 1 │\n", - "├───────┤\n", - "│ CPU 2 │\n", - "├───────┤\n", - "│ CPU 3 │\n", - "├───────┤\n", - "│ CPU 4 │\n", - "├───────┤\n", - "│ CPU 5 │\n", - "├───────┤\n", - "│ CPU 6 │\n", - "├───────┤\n", - "│ CPU 7 │\n", - "└───────┘\n" - ] + "data": { + "text/html": [ + "
  CPU 0  \n",
+       "         \n",
+       "  CPU 1  \n",
+       "         \n",
+       "  CPU 2  \n",
+       "         \n",
+       "  CPU 3  \n",
+       "         \n",
+       "  CPU 4  \n",
+       "         \n",
+       "  CPU 5  \n",
+       "         \n",
+       "  CPU 6  \n",
+       "         \n",
+       "  CPU 7  \n",
+       "         \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mCPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mCPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mCPU 4\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", + "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mCPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mCPU 6\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mCPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1454,9 +1488,9 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 34, "metadata": { - "outputId": "7d7c4c3c-2ed2-40dd-f8d3-bb1ef91d93eb" + "outputId": "f0830f6b-659c-446f-c933-7b2a430f8004" }, "outputs": [ { @@ -1467,7 +1501,7 @@ " [-0.07084481, 0.60130936]], dtype=float32)}}" ] }, - "execution_count": 75, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1493,9 +1527,9 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 35, "metadata": { - "outputId": "93152c8f-ada9-4cf5-c0e2-560a284b3981" + "outputId": "eef5c0ca-f8d5-4f25-8ce6-9f2f60622daf" }, "outputs": [ { @@ -1514,7 +1548,7 @@ " [-0.07084481, 0.60130936]]], dtype=float32)}}" ] }, - "execution_count": 76, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -1538,9 +1572,9 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 36, "metadata": { - "outputId": "dfbf4e49-55dc-4a6f-bece-396cb48d875e" + "outputId": "275699c3-ba48-403e-877d-07b65981cff5" }, "outputs": [ { @@ -1559,7 +1593,7 @@ " [-1.33515 , 0.5067442 ]]], dtype=float32)}}" ] }, - "execution_count": 77, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -1583,18 +1617,15 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 37, "metadata": { - "outputId": "20fa39a3-9e7a-4a58-8ad7-48a0119ec466" + "outputId": "c11a80bc-d865-4e2e-e059-4d6bcea79e09" }, "outputs": [ { "data": { "text/plain": [ - "{'other_collection': {'kernel': Array([[-0.8193048 , 0.711106 ],\n", - " [-0.37802765, -0.66705877],\n", - " [-0.44808003, 0.93031347]], dtype=float32)},\n", - " 'params': {'Dense_0': {'bias': Array([[0., 0.],\n", + "{'params': {'Dense_0': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.9079084 , 0.76390624],\n", @@ -1604,10 +1635,13 @@ " [ 1.5344163 , 1.3186085 ]],\n", " \n", " [[-0.44171348, 0.43430036],\n", - " [-0.40732604, 0.29774475]]], dtype=float32)}}}" + " [-0.40732604, 0.29774475]]], dtype=float32)}},\n", + " 'other_collection': {'kernel': Array([[-0.8193048 , 0.711106 ],\n", + " [-0.37802765, -0.66705877],\n", + " [-0.44808003, 0.93031347]], dtype=float32)}}" ] }, - "execution_count": 78, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1644,18 +1678,15 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 38, "metadata": { - "outputId": "69d52ec5-9018-43c8-9d77-65e9bf00b79f" + "outputId": "fb16619c-c975-497d-c867-6fd5143b4507" }, "outputs": [ { "data": { "text/plain": [ - "{'other_collection': {'kernel': Array([[ 0.44956833, -1.1854612 ],\n", - " [ 0.44956833, -1.1854612 ],\n", - " [ 0.44956833, -1.1854612 ]], dtype=float32)},\n", - " 'params': {'Dense_0': {'bias': Array([[0., 0.],\n", + "{'params': {'Dense_0': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.9079084 , 0.76390624],\n", @@ -1665,10 +1696,13 @@ " [ 1.5344163 , 1.3186085 ]],\n", " \n", " [[-0.44171348, 0.43430036],\n", - " [-0.40732604, 0.29774475]]], dtype=float32)}}}" + " [-0.40732604, 0.29774475]]], dtype=float32)}},\n", + " 'other_collection': {'kernel': Array([[ 0.44956833, -1.1854612 ],\n", + " [ 0.44956833, -1.1854612 ],\n", + " [ 0.44956833, -1.1854612 ]], dtype=float32)}}" ] }, - "execution_count": 79, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -1692,16 +1726,15 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 39, "metadata": { - "outputId": "b01187f9-aded-41bc-d17f-e2965355675a" + "outputId": "f3a17d59-6f75-4408-caba-5769d4589263" }, "outputs": [ { "data": { "text/plain": [ - "{'other_collection': {'kernel': Array([ 0.44956833, -1.1854612 ], dtype=float32)},\n", - " 'params': {'Dense_0': {'bias': Array([[0., 0.],\n", + "{'params': {'Dense_0': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.9079084 , 0.76390624],\n", @@ -1711,10 +1744,11 @@ " [ 1.5344163 , 1.3186085 ]],\n", " \n", " [[-0.44171348, 0.43430036],\n", - " [-0.40732604, 0.29774475]]], dtype=float32)}}}" + " [-0.40732604, 0.29774475]]], dtype=float32)}},\n", + " 'other_collection': {'kernel': Array([ 0.44956833, -1.1854612 ], dtype=float32)}}" ] }, - "execution_count": 80, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" } @@ -1745,9 +1779,9 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 40, "metadata": { - "outputId": "fda568b4-3c1d-4a5e-96f0-058c6fc5b49a" + "outputId": "29d1863b-809f-42ce-894c-1b0810faa41e" }, "outputs": [ { @@ -1766,7 +1800,7 @@ " [-0.15721127, -0.62520015]]], dtype=float32)}}}" ] }, - "execution_count": 81, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -1798,9 +1832,9 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 41, "metadata": { - "outputId": "620aff62-e20d-4794-b3af-5a5058c2471d" + "outputId": "6a825bcd-9c3b-43c2-afd2-42500d89fb26" }, "outputs": [ { @@ -1819,7 +1853,7 @@ " [ 0.9867164 , 0.75408363]]], dtype=float32)}}}" ] }, - "execution_count": 82, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/guides/flax_fundamentals/rng_guide.md b/docs/guides/flax_fundamentals/rng_guide.md index 5bdba84475..1030658031 100644 --- a/docs/guides/flax_fundamentals/rng_guide.md +++ b/docs/guides/flax_fundamentals/rng_guide.md @@ -29,7 +29,6 @@ Install or upgrade Flax, and then import some necessary dependencies. **Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don’t need this if you are already using a multi-device Google Cloud TPU environment, for example, on Google Cloud or in a Kaggle VM with a TPU. ```{code-cell} -:outputId: bb13d0ba-f908-4746-e4d3-f5916e0670ff :tags: [skip-execution] !pip install -q flax @@ -51,7 +50,7 @@ import hashlib ``` ```{code-cell} -:outputId: ab5e0d3b-2d51-44ee-b823-d152ad1a10b2 +:outputId: ec904f6b-0e87-4efe-87c4-fea0f8e8ec23 jax.devices() ``` @@ -73,7 +72,7 @@ The primary method Flax uses to receive, manipulate and create PRNG keys is via Note that this method can only be called with bounded modules (see [The Flax Module lifecycle](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html#top-level-modules)). ```{code-cell} -:outputId: 8304fbc2-ab0a-42f0-9d4c-c88a74012e83 +:outputId: 2a16435b-e92a-480a-f9fb-e6effc42c4c2 class RNGModule(nn.Module): @nn.compact @@ -89,7 +88,7 @@ variables = rng_module.init({'rng_stream': jax.random.key(0)}) Now if we use a different starting seed PRNG key, we will generate different values (as intended). ```{code-cell} -:outputId: 1cfd9632-11cf-43f4-b922-63038705a195 +:outputId: 985b0f62-dfde-4f0f-fad4-a31927fc9f59 variables = rng_module.init({'rng_stream': jax.random.key(1)}) ``` @@ -97,7 +96,7 @@ variables = rng_module.init({'rng_stream': jax.random.key(1)}) Calling `self.make_rng` for one stream will not affect the random values generated from another stream; i.e. the call order doesn't matter. ```{code-cell} -:outputId: 60dde8ff-08a9-4743-c7ec-c997aa0537b6 +:outputId: 7e8ce538-e380-4db9-db23-bc4a8da577da class RNGModuleTwoStreams(nn.Module): @nn.compact @@ -124,7 +123,7 @@ variables = rng_module_two_streams.init( Providing the same seed PRNG key will result in the same values being generated (provided that the same operations are used for those keys). ```{code-cell} -:outputId: 12e8c3c5-a32c-46ab-ba71-98cc90aaed70 +:outputId: b70be039-589a-48f7-dc54-65e78c449c65 variables = rng_module_two_streams.init( {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(0)} @@ -165,7 +164,7 @@ def produce_hash(data): And now you can manually reproduce the PRNG keys generated from `self.make_rng`: ```{code-cell} -:outputId: 6bf352c1-086a-4c7a-8ab9-92abde614270 +:outputId: d26b7355-9e8b-4954-b2f4-cf7520d5c5a3 stream_seed = jax.random.key(0) for call_count in range(1, 4): @@ -174,7 +173,7 @@ for call_count in range(1, 4): ``` ```{code-cell} -:outputId: 3809fddc-7319-4d55-ccc5-7b7c78160ea9 +:outputId: dec627a6-4c5a-4e3e-ce11-ce4f72775261 variables = rng_module.init({'rng_stream': jax.random.key(0)}) ``` @@ -188,7 +187,7 @@ This section explores how `self.make_rng` (`flax.linen.Module.make_rng`) behaves Consider the following example: ```{code-cell} -:outputId: cc88c770-80ff-4093-e37f-6dc10c24006f +:outputId: 5b7a9ae9-ca49-4ac0-d007-5caeee739ff0 class RNGSubSubModule(nn.Module): def __call__(self): @@ -231,7 +230,7 @@ With this data, you can manually reproduce the PRNG keys generated from the `Mod For example: ```{code-cell} -:outputId: d38896e6-8061-4f54-fa30-680b3f524071 +:outputId: c0de4d37-0f00-4e58-bdfd-e8a6454ed681 stream_seed = jax.random.key(0) for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_0', 'RNGSubSubModule_0')): @@ -248,7 +247,7 @@ for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_0', 'RNGSubSubModul If the same sub-`Module` class is used multiple times, you can increment the suffix of the sub-`Module` name accordingly. For example: `RNGSubModule_0`, `RNGSubModule_1`, and so on. ```{code-cell} -:outputId: 556a79c6-0b38-4736-f5ad-79fd1976b191 +:outputId: 0b77a038-7000-407b-c5b8-a28dea7951d1 class RNGSubModule(nn.Module): @nn.compact @@ -269,7 +268,7 @@ variables = rng_module.init({'rng_stream': jax.random.key(0)}) ``` ```{code-cell} -:outputId: 433f5e48-d379-490f-e499-ef4c24032776 +:outputId: d189d25e-425d-4fd7-fe18-2dfd63f28b87 stream_seed = jax.random.key(0) for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_1',)): @@ -296,7 +295,7 @@ There are a couple of differences between the two methods that the user should t Below is an example using both `self.param` and `self.variable`: ```{code-cell} -:outputId: fc7df4d6-5d6e-4c3c-98f2-d5b0e9a8a45e +:outputId: a7816385-0e08-48e2-dc51-055d7bcd0bab class Model(nn.Module): @nn.compact @@ -341,7 +340,7 @@ Remember: * each submodule has their own separate count for each rng stream; this is why each `Dense` layer has their own separate count for `self.make_rng('params')` and why `model_param` and `model_variable1` share the same count (since they are defined within the same top-level parent module) ```{code-cell} -:outputId: 40745795-8bbc-4c31-811a-f3658e5459d7 +:outputId: ccec9d64-9a27-47f7-adaf-b36a5ea655db params_seed = jax.random.key(0) other_seed = jax.random.key(1) @@ -407,7 +406,7 @@ variables = module.init(init_rngs, x, train=False) ``` ```{code-cell} -:outputId: 9e689bfd-ba73-4b78-c3c0-89693caada2d +:outputId: e9da8228-acba-403d-bcb5-33a39d4d530d def update(variables, rng): # we don't need to provide a 'params' or 'other' rng, as only 'dropout' rng will be used during training @@ -441,7 +440,7 @@ There is an edge case where the same value can be unintentionally generated. See the [Flax issue](https://github.com/google/flax/issues/2157) for more details. ```{code-cell} -:outputId: 80e45c76-bb4d-48ab-8fac-8b14126428d1 +:outputId: 887142ff-c9ca-4aae-d9fa-cc9993d809c5 class Leaf(nn.Module): def __call__(self, x): @@ -467,7 +466,7 @@ out1 == out2 # same output, despite having different submodule names This occurs because the hash function [concatenates strings together](https://docs.python.org/3/library/hashlib.html#hashlib.hash.update), so the data `('AB', 'C')` is equivalent to data `('A', 'BC')` when fed into the hash function, therefore producing the same hash int. ```{code-cell} -:outputId: 07fc8349-41a3-4c74-d25c-0ffc4aa1be0a +:outputId: 001cbd49-129b-4474-c6a1-3255a4ee3dfe print(produce_hash(data=('A', 'B', 'C', 1))) print(produce_hash(data=('AB', 'C', 1))) @@ -478,7 +477,7 @@ print(produce_hash(data=('ABC', 1))) To avoid this edge case, users can flip the `flax_fix_rng_separator` [configuration flag](https://flax.readthedocs.io/en/latest/api_reference/flax.config.html#flax.configurations.Config.flax_fix_rng_separator) to `True`. ```{code-cell} -:outputId: 4ba0c3cb-e903-40af-ab3f-687d83c257c9 +:outputId: 35a2b204-bdfd-4f83-8e98-ba723963cb0c flax.config.update('flax_fix_rng_separator', True) out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)}) @@ -502,7 +501,7 @@ When using [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit For more details on training on multiple devices in Flax using `jax.jit`, see our [Scale up Flax Modules on multiple devices guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html#) and [lm1b example](https://github.com/google/flax/tree/main/examples/lm1b). ```{code-cell} -:outputId: 9deab9d8-3e15-4be0-8d91-279f6984ee99 +:outputId: 6c280522-4b43-4b82-f40a-b73986659b2c # Create a mesh and annotate the axis with a name. device_mesh = mesh_utils.create_device_mesh((8,)) @@ -516,7 +515,7 @@ print(data_sharding) ``` ```{code-cell} -:outputId: 9b47eef2-612e-4d9f-ffb2-1e868cb52d86 +:outputId: d1bfbcad-e28a-4fae-8136-98bd1efb9332 class Model(nn.Module): @nn.compact @@ -553,7 +552,7 @@ The output is different given the same input, meaning the RNG key was used to ad We can also confirm that the output is sharded across devices: ```{code-cell} -:outputId: f6596bdc-89b4-46bf-ba99-84b1009d8156 +:outputId: b672b85f-7a2d-44b5-afc1-bbf9426655ed out.addressable_shards ``` @@ -561,7 +560,7 @@ out.addressable_shards Another way to visualize the output sharding: ```{code-cell} -:outputId: 7df87ccd-0979-4b85-fbac-1d31aac53276 +:outputId: 1c0a16ce-fa3f-4b95-d794-58464bbaa9ae jax.debug.visualize_array_sharding(out) ``` @@ -569,23 +568,23 @@ jax.debug.visualize_array_sharding(out) If we choose not to add noise, then the output is the same across all batches (as expected, since the input is the same for all batches): ```{code-cell} -:outputId: cb7eb7d0-3fa7-4f72-a129-18901dc61bb1 +:outputId: fe9ec875-3e7f-4861-babc-f07064737276 out = jit_forward(variables, x, False, apply_rng) -print(out) +out ``` We can confirm the un-jitted function produces the same values, albeit unsharded (note there may be small numerical differences due to compiler optimizations from jitting): ```{code-cell} -:outputId: b88834c1-ebb2-422a-a665-1732015a3974 +:outputId: 0a9e5f2c-d4bf-4051-bf71-f32a9c32dc06 out = forward(variables, x, True, apply_rng) out ``` ```{code-cell} -:outputId: 5b06f58b-7f1f-4cd6-c29f-3da5c0ea189f +:outputId: 772a5063-1bd5-46b4-f6f6-cae9b4b81a26 out = forward(variables, x, False, apply_rng) out @@ -602,7 +601,7 @@ with a batch of 8 PRNG keys and 8 devices, each device will see a PRNG key batch * therefore to access the PRNG key itself, we need to index slice into it (see the example below) ```{code-cell} -:outputId: 29b2911f-28a7-4a25-fcf7-4c3544dafada +:outputId: aa00b9a3-24ba-4048-ed8c-afbb9070f039 def forward(variables, x, add_noise, rng_key_batch): # rng_key_batch is a batch of size 1 containing 1 PRNG key @@ -633,13 +632,13 @@ out Confirm that the output is sharded across devices: ```{code-cell} -:outputId: 2d74dff9-3a7d-47df-afbc-3bda4b8edf51 +:outputId: e304289b-ef1c-4e4a-d4c1-4c41613bfa62 out.addressable_shards ``` ```{code-cell} -:outputId: 28bfe59b-ded1-455a-c894-f1c457ec28bf +:outputId: 52fdb6d2-4c4f-44b3-feee-4bc5363c8f2f jax.debug.visualize_array_sharding(out) ``` @@ -661,7 +660,7 @@ Refer to [Lifted transformations](https://flax.readthedocs.io/en/latest/develope We can use [`nn.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.vmap.html) to create a batched `Dense` layer: ```{code-cell} -:outputId: 7d7c4c3c-2ed2-40dd-f8d3-bb1ef91d93eb +:outputId: f0830f6b-659c-446f-c933-7b2a430f8004 x = jnp.ones((3, 2)) @@ -677,7 +676,7 @@ BatchDense(2).init(jax.random.key(0), x) By denoting `variable_axes={'params': 0}'`, we vectorize the `params` Arrays on the first axis. However the parameter values generated are all identical to each other: ```{code-cell} -:outputId: 93152c8f-ada9-4cf5-c0e2-560a284b3981 +:outputId: eef5c0ca-f8d5-4f25-8ce6-9f2f60622daf BatchDense = nn.vmap( nn.Dense, @@ -691,7 +690,7 @@ BatchDense(2).init(jax.random.key(0), x) If we also make `split_rngs={'params': True}`, then the PRNG key we provide is split across the variable axis (in this case, the batch axis 0), and we can generate different parameters for each batch input: ```{code-cell} -:outputId: dfbf4e49-55dc-4a6f-bece-396cb48d875e +:outputId: 275699c3-ba48-403e-877d-07b65981cff5 BatchDense = nn.vmap( nn.Dense, @@ -705,7 +704,7 @@ BatchDense(2).init(jax.random.key(0), x) Adding a variable via `self.variable` is straightforward: ```{code-cell} -:outputId: 20fa39a3-9e7a-4a58-8ad7-48a0119ec466 +:outputId: c11a80bc-d865-4e2e-e059-4d6bcea79e09 class Model(nn.Module): @nn.compact @@ -732,7 +731,7 @@ BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x) We can control which RNG stream to split, for example, if we only wanted to split the `'params'` RNG stream, then the variables generated from `self.variable` will be the same for each batch input: ```{code-cell} -:outputId: 69d52ec5-9018-43c8-9d77-65e9bf00b79f +:outputId: fb16619c-c975-497d-c867-6fd5143b4507 BatchModel = nn.vmap( Model, @@ -746,7 +745,7 @@ BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x) We can also control which parameters / variables should be generated for each batch input, for example, if we only wanted `'params'` to generate separate parameters for each batch input: ```{code-cell} -:outputId: b01187f9-aded-41bc-d17f-e2965355675a +:outputId: f3a17d59-6f75-4408-caba-5769d4589263 BatchModel = nn.vmap( Model, @@ -764,7 +763,7 @@ BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x) We can use [`nn.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.scan.html) to create a scanned `Module` layer (this is useful for simplifying repetitively stacked submodules): ```{code-cell} -:outputId: fda568b4-3c1d-4a5e-96f0-058c6fc5b49a +:outputId: 29d1863b-809f-42ce-894c-1b0810faa41e x = jnp.ones((3, 2)) @@ -786,7 +785,7 @@ ScanMLP().init(jax.random.key(0), x, None) # pass in an empty carry Similar to before, we can control whether to split the RNG stream or not, for example, if we wanted all the stacked modules to be initialized to the same parameter values, we can pass in `split_rngs={'params': False}`: ```{code-cell} -:outputId: 620aff62-e20d-4794-b3af-5a5058c2471d +:outputId: 6a825bcd-9c3b-43c2-afd2-42500d89fb26 ScanMLP = nn.scan( ResidualMLPBlock, variable_axes={'params': 0},