From 823c6160aad4de723ba18c8cb2107ce8861301ad Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Mon, 1 Apr 2024 20:58:59 -0700 Subject: [PATCH] updated nnx mnist tutorial --- docs/experimental/nnx/mnist_tutorial.ipynb | 266 +++++++++++---------- docs/experimental/nnx/mnist_tutorial.md | 194 ++++++++------- 2 files changed, 236 insertions(+), 224 deletions(-) diff --git a/docs/experimental/nnx/mnist_tutorial.ipynb b/docs/experimental/nnx/mnist_tutorial.ipynb index e63b59b073..fc91fe50b9 100644 --- a/docs/experimental/nnx/mnist_tutorial.ipynb +++ b/docs/experimental/nnx/mnist_tutorial.ipynb @@ -59,9 +59,9 @@ "outputs": [], "source": [ "import tensorflow_datasets as tfds # TFDS for MNIST\n", - "import tensorflow as tf # TensorFlow operations\n", + "import tensorflow as tf # TensorFlow operations\n", "\n", - "tf.random.set_seed(0) # set random seed for reproducibility\n", + "tf.random.set_seed(0) # set random seed for reproducibility\n", "\n", "num_epochs = 10\n", "batch_size = 32\n", @@ -69,17 +69,27 @@ "train_ds: tf.data.Dataset = tfds.load('mnist', split='train')\n", "test_ds: tf.data.Dataset = tfds.load('mnist', split='test')\n", "\n", - "train_ds = train_ds.map(lambda sample: {\n", - " 'image': tf.cast(sample['image'],tf.float32) / 255,\n", - " 'label': sample['label']}) # normalize train set\n", - "test_ds = test_ds.map(lambda sample: {\n", - " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", - " 'label': sample['label']}) # normalize test set\n", - "\n", - "train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", - "train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", - "test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", - "test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency" + "train_ds = train_ds.map(\n", + " lambda sample: {\n", + " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", + " 'label': sample['label'],\n", + " }\n", + ") # normalize train set\n", + "test_ds = test_ds.map(\n", + " lambda sample: {\n", + " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", + " 'label': sample['label'],\n", + " }\n", + ") # normalize test set\n", + "\n", + "# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + "train_ds = train_ds.repeat(num_epochs).shuffle(1024)\n", + "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)\n", + "# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + "test_ds = test_ds.shuffle(1024)\n", + "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)" ] }, { @@ -117,8 +127,8 @@ " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", - " kernel_init=.init at 0x3559cb700>,\n", - " bias_init=,\n", + " kernel_init=.init at 0x35cbd31f0>,\n", + " bias_init=,\n", " conv_general_dilated=