From 5e2c7a3741cf0014e38f3f2d21ec5bf230582a87 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20G=C3=B6rner?=
Date: Wed, 21 Jun 2023 21:27:41 +0200
Subject: [PATCH 1/4] added Jax distributed training example using a Keras model

---
 examples/demo_jax_distributed.py | 268 +++++++++++++++++++++++++++++++
 1 file changed, 268 insertions(+)
 create mode 100644 examples/demo_jax_distributed.py

diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py
new file mode 100644
index 000000000..b2c396062
--- /dev/null
+++ b/examples/demo_jax_distributed.py
@@ -0,0 +1,268 @@
+# To run this demo, you will need to spin up a "TPU VM" on Google Cloud.
# Please follow instructions here: https://cloud.google.com/tpu/docs/run-calculation-jax

# Force a JAX backend
import os, pprint, collections
os.environ['KERAS_BACKEND'] = 'jax'

pp = pprint.PrettyPrinter()

import jax
import jax.numpy as jnp
import tensorflow as tf # just for tf.data
import keras_core as keras # Keras multi-backend

import numpy as np
from tqdm import tqdm

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

""" Dataset
Classic MNIST, loaded using tf.data
"""

BATCH_SIZE=192

(x_train, train_labels), (x_eval, eval_labels) = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1).astype(np.float32) # from 28x28 to 28x28 x 1 color channel (B&W)
x_eval = np.expand_dims(x_eval, axis=-1).astype(np.float32)

train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)
train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)
train_data = train_data.repeat()

eval_data = tf.data.Dataset.from_tensor_slices((x_eval, eval_labels))
eval_data = eval_data.batch(10000) # everything as one batch

STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE

""" Keras model
Simple but non-trivial model with:
* Batch Normalization (non-trainable state updated during training, different training-time and inference behavior)
* Dropout (randomness, different training time and inference behavior)
"""

# Keras "sequential" model building style
def make_backbone():
    return keras.Sequential([
        keras.layers.Rescaling(1./255.), # input images are in the range [0, 255]

        keras.layers.Conv2D(filters=12, kernel_size=3, padding='same', use_bias=False),
        keras.layers.BatchNormalization(scale=False, center=True),
        keras.layers.Activation('relu'),

        keras.layers.Conv2D(filters=24, kernel_size=6, padding='same', use_bias=False, strides=2),
        keras.layers.BatchNormalization(scale=False, center=True),
        keras.layers.Activation('relu'),

        keras.layers.Conv2D(filters=32, kernel_size=6, padding='same', use_bias=False, strides=2, name='large_k'),
        keras.layers.BatchNormalization(scale=False, center=True),
        keras.layers.Activation('relu'),
    ], name="backbone")

def make_model():
    input = keras.Input(shape=[28, 28, 1])
    y = make_backbone()(input)
    y = keras.layers.Flatten()(y)
    y = keras.layers.Dense(200, activation="relu")(y)
    y = keras.layers.Dropout(0.4)(y)
    y = keras.layers.Dense(10, activation='softmax')(y)
    model = keras.Model(inputs=input, outputs=y)
    return model

""" JAX-native distribution with a Keras model
For now, you have to write a custom training loop for this
Note: The features required by jax.sharding are not supported by the Colab TPU
runtime at this time, but are available on
Cloud TPU VMs and Kaggle TPU VMs.
"""

if len(jax.local_devices()) < 8:
    raise Exception("This part requires 8 devices to run")
else:
    print("\nIdentified local devices:")
    pp.pprint(jax.local_devices())

# ----------------- Keras ---------------------

# instantiate the model
model = make_model()

# learning rate
lr = keras.optimizers.schedules.ExponentialDecay(0.01, STEPS_PER_EPOCH, 0.6)

# optimizer
optimizer = keras.optimizers.Adam(lr)

# initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)

""" Distribution settings

* Sharding the data on the batch axis
* Replicating all model variables

Note: this implements standard "data parallel" distributed training

* Just for show, sharding the largest convolutional kernel along the
  "channels" axis 4-ways and replicating 2-ways

Note: this does not reflect a best practice but is intended to show
  that you can split a very large kernel across multiple devices
  if you have to
"""

print("\nMostly data-parallel distribution. "
      "Data is sharded across devices while the model is replicated. "
      "For demo purposes, we split the largest kernel 4-ways "
      "(and replicate 2-ways since we have 8 devices).")

# ------------------ Jax ----------------------

devices = mesh_utils.create_device_mesh((8,))

# data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=('batch',)) # naming axes of the mesh
data_sharding = NamedSharding(data_mesh, P('batch',)) # naming axes of the sharded partition

# all variables will be replicated on all devices
var_mesh = Mesh(devices, axis_names=('_'))
var_replication = NamedSharding(var_mesh, P()) # in NamedSharding, axes that are not mentioned are replicated (all axes here)

# for the demo, we will split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices)
large_kernel_mesh = Mesh(devices.reshape((-1,4)), axis_names=(None, 'out_chan')) # naming axes of the mesh
large_kernel_sharding = NamedSharding(large_kernel_mesh, P(None, None, None, 'out_chan')) # naming axes of the sharded partition

# ----------------- Keras ---------------------

# Use Keras APIs to find the variable of a specific layer (we will be sharding this one in a special way)
# In a Conv2D or Dense layer, the variables are 'kernel' and 'bias'
special_layer_var = model.get_layer("backbone").get_layer("large_k").kernel

# ------------------ Jax ----------------------
# - accessing variables in Keras lists model.trainable_variables,
# - model.non_trainable_variables and optimizer.variables

# Apply the distribution settings to the model variables
non_trainable_variables = jax.device_put(model.non_trainable_variables, var_replication)
optimizer_variables = jax.device_put(optimizer.variables, var_replication)
# this is what you would do to replicate all trainable variables:
# trainable_variables = jax.device_put(model.trainable_variables, var_replication)

# For the demo, we split the largest kernel 4-ways instead of replicating it.
# We still replicate all other trainable variables as in standard "data-parallel"
# distributed training.
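
# (Reviewer's illustration, not part of the original example.) A quick way to
# read a PartitionSpec: each entry names the mesh axis that the matching array
# dimension is split over, and None means "replicate along that dimension".
# With the (2, 4) device mesh defined above, large_kernel_sharding therefore
# splits only the last (output-channel) dimension of the "large_k" kernel
# 4-ways and replicates those shards 2-ways. Assuming the expected kernel shape
# of (6, 6, 24, 32) for that layer, each device should hold a (6, 6, 24, 8)
# shard, which Sharding.shard_shape can confirm:
print("Expected per-device shard of the large kernel:",
      large_kernel_sharding.shard_shape((6, 6, 24, 32)))
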
print_once=True
trainable_variables = model.trainable_variables
for i,v in enumerate(trainable_variables):
    if v is special_layer_var:

        # Apply distribution settings: sharding
        sharded_v = jax.device_put(v, large_kernel_sharding)
        trainable_variables[i] = sharded_v

        print("Sharding of convolutional", v.name, v.shape)
        jax.debug.visualize_array_sharding(jnp.reshape(sharded_v, [-1, v.shape[-1]]))
    else:
        # Apply distribution settings: replication
        replicated_v = jax.device_put(v, var_replication)
        trainable_variables[i] = replicated_v

        if (print_once):
            print_once=False
            print("\nSharding of all other model variables (they are replicated)")
            jax.debug.visualize_array_sharding(jnp.reshape(replicated_v, [-1, v.shape[-1]]))

# collect state in a handy named tuple
TrainingState = collections.namedtuple('TrainingState',
        ['trainable_variables', 'non_trainable_variables', 'optimizer_variables'])
device_train_state = TrainingState(trainable_variables=trainable_variables,
                                   non_trainable_variables=non_trainable_variables,
                                   optimizer_variables=optimizer_variables)

# display data sharding
x,y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jnp.reshape(sharded_x, [-1, 28*28]))

# ------------------ Jax ----------------------
# - Using Keras-provided stateless APIs
# - model.stateless_call
# - optimizer.stateless_apply
# These functions also work on other backends.

# define loss
loss = keras.losses.SparseCategoricalCrossentropy()

# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x)
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables

# function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)

# Training step: Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    (loss_value, non_trainable_variables), grads = compute_gradients(
        train_state.trainable_variables, train_state.non_trainable_variables,
        x, y)

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        grads,
        train_state.trainable_variables, train_state.optimizer_variables)

    return loss_value, TrainingState(trainable_variables,
                                     non_trainable_variables,
                                     optimizer_variables)

# training loop
EPOCHS=5
print("\nTraining:")
data_iter = iter(train_data)
for epoch in range(EPOCHS):
    for i in tqdm(range(STEPS_PER_EPOCH)):
        x, y = next(data_iter)
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, device_train_state = train_step(device_train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value)

# The output of the model is still sharded. Sharding follows the data.
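# (Reviewer's note, not part of the original example.) "Sharding follows the
# data" because jax.jit, when given sharded inputs, lets the compiler propagate
# those shardings through the computation and pick matching output shardings.
# The predictions computed below therefore stay sharded across the 8 devices;
# if a single host array is ever needed, the shards can be gathered back with
# something like:
#     host_predictions = np.asarray(predictions)  # or jax.device_get(predictions)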
+ +data, labels = next(iter(eval_data)) +sharded_data = jax.device_put(data.numpy(), data_sharding) + +@jax.jit +def predict(data): + predictions, updated_non_trainable_variables = model.stateless_call( + device_train_state.trainable_variables, + device_train_state.non_trainable_variables, data) + return predictions + +predictions = predict(sharded_data) +print("\nModel output sharding follows data sharding:") +jax.debug.visualize_array_sharding(predictions) + +# Post-processing model state update to write them back into the model +update = lambda variable, value: variable.assign(value) + +jax.tree_map(update, model.trainable_variables, device_train_state.trainable_variables) +jax.tree_map(update, model.non_trainable_variables, device_train_state.non_trainable_variables) +jax.tree_map(update, optimizer.variables, device_train_state.optimizer_variables) + +# check that the model has the new state by running an eval +# known issue: the optimizer should not be required here +model.compile(loss=keras.losses.SparseCategoricalCrossentropy(), + metrics=[keras.metrics.SparseCategoricalAccuracy()]) +print("\nUpdating model and running an eval:") +loss, accuracy = model.evaluate(eval_data) +print("The model achieved an evaluation accuracy of:", accuracy) \ No newline at end of file From a7040a82385279f26e2b2a1f73a6f671a515a6b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20G=C3=B6rner?= Date: Thu, 22 Jun 2023 17:58:18 +0200 Subject: [PATCH 2/4] fixed file formatting --- examples/demo_jax_distributed.py | 206 +++++++++++++++++++++---------- 1 file changed, 140 insertions(+), 66 deletions(-) diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py index b2c396062..d0e195fec 100644 --- a/examples/demo_jax_distributed.py +++ b/examples/demo_jax_distributed.py @@ -3,14 +3,15 @@ # Force a JAX backend import os, pprint, collections -os.environ['KERAS_BACKEND'] = 'jax' + +os.environ["KERAS_BACKEND"] = "jax" pp = pprint.PrettyPrinter() import jax import jax.numpy as jnp -import tensorflow as tf # just for tf.data -import keras_core as keras # Keras multi-backend +import tensorflow as tf # just for tf.data +import keras_core as keras # Keras multi-backend import numpy as np from tqdm import tqdm @@ -24,10 +25,15 @@ Classic MNIST, loaded using tf.data """ -BATCH_SIZE=192 +BATCH_SIZE = 192 -(x_train, train_labels), (x_eval, eval_labels) = keras.datasets.mnist.load_data() -x_train = np.expand_dims(x_train, axis=-1).astype(np.float32) # from 28x28 to 28x28 x 1 color channel (B&W) +(x_train, train_labels), ( + x_eval, + eval_labels, +) = keras.datasets.mnist.load_data() +x_train = np.expand_dims(x_train, axis=-1).astype( + np.float32 +) # from 28x28 to 28x28 x 1 color channel (B&W) x_eval = np.expand_dims(x_eval, axis=-1).astype(np.float32) train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels)) @@ -36,9 +42,9 @@ train_data = train_data.repeat() eval_data = tf.data.Dataset.from_tensor_slices((x_eval, eval_labels)) -eval_data = eval_data.batch(10000) # everything as one batch +eval_data = eval_data.batch(10000) # everything as one batch -STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE +STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE """ Keras model Simple but non-trivial model with: @@ -46,23 +52,42 @@ * Dropout (randomness, different training time and inference behavior) """ + # Keras "sequential" model building style def make_backbone(): - return keras.Sequential([ - keras.layers.Rescaling(1./255.), # input images are in the range [0, 255] - - 
keras.layers.Conv2D(filters=12, kernel_size=3, padding='same', use_bias=False), - keras.layers.BatchNormalization(scale=False, center=True), - keras.layers.Activation('relu'), + return keras.Sequential( + [ + keras.layers.Rescaling( + 1.0 / 255.0 + ), # input images are in the range [0, 255] + keras.layers.Conv2D( + filters=12, kernel_size=3, padding="same", use_bias=False + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Conv2D( + filters=24, + kernel_size=6, + padding="same", + use_bias=False, + strides=2, + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Conv2D( + filters=32, + kernel_size=6, + padding="same", + use_bias=False, + strides=2, + name="large_k", + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + ], + name="backbone", + ) - keras.layers.Conv2D(filters=24, kernel_size=6, padding='same', use_bias=False, strides=2), - keras.layers.BatchNormalization(scale=False, center=True), - keras.layers.Activation('relu'), - - keras.layers.Conv2D(filters=32, kernel_size=6, padding='same', use_bias=False, strides=2, name='large_k'), - keras.layers.BatchNormalization(scale=False, center=True), - keras.layers.Activation('relu'), - ], name="backbone") def make_model(): input = keras.Input(shape=[28, 28, 1]) @@ -70,10 +95,11 @@ def make_model(): y = keras.layers.Flatten()(y) y = keras.layers.Dense(200, activation="relu")(y) y = keras.layers.Dropout(0.4)(y) - y = keras.layers.Dense(10, activation='softmax')(y) + y = keras.layers.Dense(10, activation="softmax")(y) model = keras.Model(inputs=input, outputs=y) return model + """ JAX-native distribution with a Keras model For now, you have to write a custom training loop for this Note: The features required by jax.sharding are not supported by the Colab TPU @@ -117,26 +143,39 @@ def make_model(): if you have to """ -print("\nMostly data-parallel distribution. " - "Data is sharded across devices while the model is replicated. " - "For demo purposes, we split the largest kernel 4-ways " - "(and replicate 2-ways since we have 8 devices).") +print( + "\nMostly data-parallel distribution. " + "Data is sharded across devices while the model is replicated. " + "For demo purposes, we split the largest kernel 4-ways " + "(and replicate 2-ways since we have 8 devices)." 
+)
 
# ------------------ Jax ----------------------

devices = mesh_utils.create_device_mesh((8,))

# data will be split along the batch axis
-data_mesh = Mesh(devices, axis_names=('batch',)) # naming axes of the mesh
-data_sharding = NamedSharding(data_mesh, P('batch',)) # naming axes of the sharded partition
+data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
+data_sharding = NamedSharding(
+    data_mesh,
+    P(
+        "batch",
+    ),
+)  # naming axes of the sharded partition

# all variables will be replicated on all devices
-var_mesh = Mesh(devices, axis_names=('_'))
-var_replication = NamedSharding(var_mesh, P()) # in NamedSharding, axes that are not mentioned are replicated (all axes here)
+var_mesh = Mesh(devices, axis_names=("_"))
+var_replication = NamedSharding(
+    var_mesh, P()
+)  # in NamedSharding, axes that are not mentioned are replicated (all axes here)

# for the demo, we will split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices)
-large_kernel_mesh = Mesh(devices.reshape((-1,4)), axis_names=(None, 'out_chan')) # naming axes of the mesh
-large_kernel_sharding = NamedSharding(large_kernel_mesh, P(None, None, None, 'out_chan')) # naming axes of the sharded partition
+large_kernel_mesh = Mesh(
+    devices.reshape((-1, 4)), axis_names=(None, "out_chan")
+)  # naming axes of the mesh
+large_kernel_sharding = NamedSharding(
+    large_kernel_mesh, P(None, None, None, "out_chan")
+)  # naming axes of the sharded partition

# ----------------- Keras ---------------------
@@ -149,7 +188,9 @@ def make_model():
# - model.non_trainable_variables and optimizer.variables

# Apply the distribution settings to the model variables
-non_trainable_variables = jax.device_put(model.non_trainable_variables, var_replication)
+non_trainable_variables = jax.device_put(
+    model.non_trainable_variables, var_replication
+)
optimizer_variables = jax.device_put(optimizer.variables, var_replication)
# this is what you would do to replicate all trainable variables:
# trainable_variables = jax.device_put(model.trainable_variables, var_replication)

# For the demo, we split the largest kernel 4-ways instead of replicating it.
# We still replicate all other trainable variables as in standard "data-parallel"
# distributed training.
-print_once=True +print_once = True trainable_variables = model.trainable_variables -for i,v in enumerate(trainable_variables): +for i, v in enumerate(trainable_variables): if v is special_layer_var: - # Apply distribution settings: sharding sharded_v = jax.device_put(v, large_kernel_sharding) trainable_variables[i] = sharded_v print("Sharding of convolutional", v.name, v.shape) - jax.debug.visualize_array_sharding(jnp.reshape(sharded_v, [-1, v.shape[-1]])) + jax.debug.visualize_array_sharding( + jnp.reshape(sharded_v, [-1, v.shape[-1]]) + ) else: # Apply distribution settings: replication replicated_v = jax.device_put(v, var_replication) trainable_variables[i] = replicated_v - if (print_once): - print_once=False - print("\nSharding of all other model variables (they are replicated)") - jax.debug.visualize_array_sharding(jnp.reshape(replicated_v, [-1, v.shape[-1]])) + if print_once: + print_once = False + print( + "\nSharding of all other model variables (they are replicated)" + ) + jax.debug.visualize_array_sharding( + jnp.reshape(replicated_v, [-1, v.shape[-1]]) + ) # collect state in a handy named tuple -TrainingState = collections.namedtuple('TrainingState', - ['trainable_variables', 'non_trainable_variables', 'optimizer_variables']) -device_train_state = TrainingState(trainable_variables=trainable_variables, - non_trainable_variables=non_trainable_variables, - optimizer_variables=optimizer_variables) +TrainingState = collections.namedtuple( + "TrainingState", + ["trainable_variables", "non_trainable_variables", "optimizer_variables"], +) +device_train_state = TrainingState( + trainable_variables=trainable_variables, + non_trainable_variables=non_trainable_variables, + optimizer_variables=optimizer_variables, +) # display data sharding -x,y = next(iter(train_data)) +x, y = next(iter(train_data)) sharded_x = jax.device_put(x.numpy(), data_sharding) print("Data sharding") -jax.debug.visualize_array_sharding(jnp.reshape(sharded_x, [-1, 28*28])) +jax.debug.visualize_array_sharding(jnp.reshape(sharded_x, [-1, 28 * 28])) # ------------------ Jax ---------------------- # - Using Keras-provided stateless APIs @@ -199,41 +249,51 @@ def make_model(): # define loss loss = keras.losses.SparseCategoricalCrossentropy() + # This is the loss function that will be differentiated. 
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
-        trainable_variables, non_trainable_variables, x)
+        trainable_variables, non_trainable_variables, x
+    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables

+
# function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)

+
# Training step: Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    (loss_value, non_trainable_variables), grads = compute_gradients(
-        train_state.trainable_variables, train_state.non_trainable_variables,
-        x, y)
+        train_state.trainable_variables,
+        train_state.non_trainable_variables,
+        x,
+        y,
+    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
-        grads,
-        train_state.trainable_variables, train_state.optimizer_variables)
+        grads, train_state.trainable_variables, train_state.optimizer_variables
+    )
+
+    return loss_value, TrainingState(
+        trainable_variables, non_trainable_variables, optimizer_variables
+    )

-    return loss_value, TrainingState(trainable_variables,
-                                     non_trainable_variables,
-                                     optimizer_variables)

# training loop
-EPOCHS=5
+EPOCHS = 5
print("\nTraining:")
data_iter = iter(train_data)
for epoch in range(EPOCHS):
    for i in tqdm(range(STEPS_PER_EPOCH)):
        x, y = next(data_iter)
        sharded_x = jax.device_put(x.numpy(), data_sharding)
-        loss_value, device_train_state = train_step(device_train_state, sharded_x, y.numpy())
+        loss_value, device_train_state = train_step(
+            device_train_state, sharded_x, y.numpy()
+        )
    print("Epoch", epoch, "loss:", loss_value)

# The output of the model is still sharded. Sharding follows the data.
@@ -241,13 +301,17 @@ def train_step(train_state, x, y):

data, labels = next(iter(eval_data))
sharded_data = jax.device_put(data.numpy(), data_sharding)

+
@jax.jit
def predict(data):
    predictions, updated_non_trainable_variables = model.stateless_call(
        device_train_state.trainable_variables,
-        device_train_state.non_trainable_variables, data)
+        device_train_state.non_trainable_variables,
+        data,
+    )
    return predictions

+
predictions = predict(sharded_data)
print("\nModel output sharding follows data sharding:")
jax.debug.visualize_array_sharding(predictions)
@@ -255,14 +319,24 @@ def predict(data):
# Post-processing model state update to write them back into the model
update = lambda variable, value: variable.assign(value)

-jax.tree_map(update, model.trainable_variables, device_train_state.trainable_variables)
-jax.tree_map(update, model.non_trainable_variables, device_train_state.non_trainable_variables)
-jax.tree_map(update, optimizer.variables, device_train_state.optimizer_variables)
+jax.tree_map(
+    update, model.trainable_variables, device_train_state.trainable_variables
+)
+jax.tree_map(
+    update,
+    model.non_trainable_variables,
+    device_train_state.non_trainable_variables,
+)
+jax.tree_map(
+    update, optimizer.variables, device_train_state.optimizer_variables
+)

# check that the model has the new state by running an eval
# known issue: the optimizer should not be required here
-model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),
-              metrics=[keras.metrics.SparseCategoricalAccuracy()])
+model.compile(
+    loss=keras.losses.SparseCategoricalCrossentropy(),
+    metrics=[keras.metrics.SparseCategoricalAccuracy()],
+)
print("\nUpdating model and running an eval:")
loss, accuracy = model.evaluate(eval_data)
-print("The model achieved an evaluation accuracy of:", accuracy)
\ No newline at end of file
+print("The model achieved an evaluation accuracy of:", accuracy)

From 266fb0eaeee5fe9db7716e3b4a1159aac46ac6bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20G=C3=B6rner?=
Date: Thu, 22 Jun 2023 18:03:31 +0200
Subject: [PATCH 3/4] fixed file formatting

---
 examples/demo_jax_distributed.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py
index d0e195fec..328490f0a 100644
--- a/examples/demo_jax_distributed.py
+++ b/examples/demo_jax_distributed.py
@@ -156,18 +156,13 @@ def make_model():

# data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
-data_sharding = NamedSharding(
-    data_mesh,
-    P(
-        "batch",
-    ),
-)  # naming axes of the sharded partition
+# naming axes of the sharded partition
+data_sharding = NamedSharding(data_mesh,P("batch",),)

# all variables will be replicated on all devices
var_mesh = Mesh(devices, axis_names=("_"))
-var_replication = NamedSharding(
-    var_mesh, P()
-)  # in NamedSharding, axes that are not mentioned are replicated (all axes here)
+# in NamedSharding, axes that are not mentioned are replicated (all axes here)
+var_replication = NamedSharding(var_mesh, P())

# for the demo, we will split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices)
large_kernel_mesh = Mesh(

From ded4269e289e6968adc0118c9f9e3862b008ceff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20G=C3=B6rner?=
Date: Wed, 12 Jul 2023 13:35:53 +0200
Subject: [PATCH 4/4] the order of arguments in stateless_apply has changed. Fixed example.

--- examples/demo_jax_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py index 328490f0a..9aa8f45c7 100644 --- a/examples/demo_jax_distributed.py +++ b/examples/demo_jax_distributed.py @@ -270,7 +270,7 @@ def train_step(train_state, x, y): ) trainable_variables, optimizer_variables = optimizer.stateless_apply( - grads, train_state.trainable_variables, train_state.optimizer_variables + train_state.optimizer_variables, grads, train_state.trainable_variables ) return loss_value, TrainingState(
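
[Reviewer's note, appended for reference; not part of the patches above.] A
minimal sketch of the call convention that PATCH 4/4 switches to, assuming the
keras_core stateless optimizer API used throughout the example: the optimizer
state is passed first, then the gradients, then the variables being updated,
and fresh variable and state lists are returned rather than mutated in place,
which is what keeps the train_step usable under jax.jit:

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        train_state.optimizer_variables, grads, train_state.trainable_variables
    )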