diff --git a/docs/experimental/nnx/mnist_tutorial.ipynb b/docs/experimental/nnx/mnist_tutorial.ipynb index e63b59b073..fc91fe50b9 100644 --- a/docs/experimental/nnx/mnist_tutorial.ipynb +++ b/docs/experimental/nnx/mnist_tutorial.ipynb @@ -59,9 +59,9 @@ "outputs": [], "source": [ "import tensorflow_datasets as tfds # TFDS for MNIST\n", - "import tensorflow as tf # TensorFlow operations\n", + "import tensorflow as tf # TensorFlow operations\n", "\n", - "tf.random.set_seed(0) # set random seed for reproducibility\n", + "tf.random.set_seed(0) # set random seed for reproducibility\n", "\n", "num_epochs = 10\n", "batch_size = 32\n", @@ -69,17 +69,27 @@ "train_ds: tf.data.Dataset = tfds.load('mnist', split='train')\n", "test_ds: tf.data.Dataset = tfds.load('mnist', split='test')\n", "\n", - "train_ds = train_ds.map(lambda sample: {\n", - " 'image': tf.cast(sample['image'],tf.float32) / 255,\n", - " 'label': sample['label']}) # normalize train set\n", - "test_ds = test_ds.map(lambda sample: {\n", - " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", - " 'label': sample['label']}) # normalize test set\n", - "\n", - "train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", - "train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", - "test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", - "test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency" + "train_ds = train_ds.map(\n", + " lambda sample: {\n", + " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", + " 'label': sample['label'],\n", + " }\n", + ") # normalize train set\n", + "test_ds = test_ds.map(\n", + " lambda sample: {\n", + " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", + " 'label': sample['label'],\n", + " }\n", + ") # normalize test set\n", + "\n", + "# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + "train_ds = train_ds.repeat(num_epochs).shuffle(1024)\n", + "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)\n", + "# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + "test_ds = test_ds.shuffle(1024)\n", + "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)" ] }, { @@ -117,8 +127,8 @@ " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", - " kernel_init=.init at 0x3559cb700>,\n", - " bias_init=,\n", + " kernel_init=.init at 0x35cbd31f0>,\n", + " bias_init=,\n", " conv_general_dilated=