From 7cffbfb3341665c7bc6ddf71bac32a30762db80e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20G=C3=B6rner?=
Date: Wed, 12 Jul 2023 13:47:00 +0200
Subject: [PATCH 1/3] fixed order of parameters in stateless_apply in JAX distributed example (#458)

* added JAX distributed training example using a Keras model
* fixed file formatting
* fixed file formatting
* the order of arguments in stateless_apply has changed. Fixed example.
---
 examples/demo_jax_distributed.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py
index 9aef7060a..68cf068d6 100644
--- a/examples/demo_jax_distributed.py
+++ b/examples/demo_jax_distributed.py
@@ -157,13 +157,7 @@ def make_model():
 # data will be split along the batch axis
 data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
 # naming axes of the sharded partition
-data_sharding = NamedSharding(
-    data_mesh,
-    P(
-        "batch",
-    ),
-)
-
+data_sharding = NamedSharding(data_mesh,P("batch",),)
 # all variables will be replicated on all devices
 var_mesh = Mesh(devices, axis_names=("_"))
 # in NamedSharding, axes that are not mentioned are replicated (all axes here)
@@ -275,7 +269,7 @@ def train_step(train_state, x, y):
     )

     trainable_variables, optimizer_variables = optimizer.stateless_apply(
-        grads, train_state.trainable_variables, train_state.optimizer_variables
+        train_state.optimizer_variables, grads, train_state.trainable_variables
     )

     return loss_value, TrainingState(

From acb2517ba1e6ed15cd934a424661f4e6552bab07 Mon Sep 17 00:00:00 2001
From: hertschuh <1091026+hertschuh@users.noreply.github.com>
Date: Wed, 12 Jul 2023 13:47:28 -0700
Subject: [PATCH 2/3] Added JAX distributed training guide. (#464)

---
 guides/distributed_training_with_jax.py | 253 ++++++++++++++++++++++++
 1 file changed, 253 insertions(+)

diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py
index e69de29bb..97bc8cef9 100644
--- a/guides/distributed_training_with_jax.py
+++ b/guides/distributed_training_with_jax.py
@@ -0,0 +1,253 @@
+"""
+Title: Multi-GPU distributed training with JAX
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2023/07/11
+Last modified: 2023/07/11
+Description: Guide to multi-GPU/TPU training for Keras models with JAX.
+Accelerator: GPU or TPU
+"""
+"""
+## Introduction
+
+There are generally two ways to distribute computation across multiple devices:
+
+**Data parallelism**, where a single model gets replicated on multiple devices or
+multiple machines. Each of them processes different batches of data, then they merge
+their results. There exist many variants of this setup, which differ in how the different
+model replicas merge results, in whether they stay in sync at every batch or whether they
+are more loosely coupled, etc.
+
+**Model parallelism**, where different parts of a single model run on different devices,
+processing a single batch of data together. This works best with models that have a
+naturally-parallel architecture, such as models that feature multiple branches.
+
+This guide focuses on data parallelism, in particular **synchronous data parallelism**,
+where the different replicas of the model stay in sync after each batch they process.
+Synchronicity keeps the model convergence behavior identical to what you would see for
+single-device training.
+
+Specifically, this guide teaches you how to use `jax.sharding` APIs to train Keras
+models, with minimal changes to your code, on multiple GPUs or TPUs (typically 2 to 16)
+installed on a single machine (single host, multi-device training). This is the
+most common setup for researchers and small-scale industry workflows.
+"""
+
+"""
+## Setup
+
+Let's start by defining the function that creates the model that we will train,
+and the function that creates the dataset we will train on (MNIST in this case).
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+
+import jax
+import numpy as np
+import tensorflow as tf
+import keras_core as keras
+
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
+
+
+def get_model():
+    # Make a simple convnet with batch normalization and dropout.
+    inputs = keras.Input(shape=(28, 28, 1))
+    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
+    x = keras.layers.Conv2D(
+        filters=12, kernel_size=3, padding="same", use_bias=False
+    )(x)
+    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
+    x = keras.layers.ReLU()(x)
+    x = keras.layers.Conv2D(
+        filters=24,
+        kernel_size=6,
+        use_bias=False,
+        strides=2,
+    )(x)
+    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
+    x = keras.layers.ReLU()(x)
+    x = keras.layers.Conv2D(
+        filters=32,
+        kernel_size=6,
+        padding="same",
+        strides=2,
+        name="large_k",
+    )(x)
+    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
+    x = keras.layers.ReLU()(x)
+    x = keras.layers.GlobalAveragePooling2D()(x)
+    x = keras.layers.Dense(256, activation="relu")(x)
+    x = keras.layers.Dropout(0.5)(x)
+    outputs = keras.layers.Dense(10)(x)
+    model = keras.Model(inputs, outputs)
+    return model
+
+
+def get_datasets():
+    # Load the data and split it between train and test sets
+    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+    # Scale images to the [0, 1] range
+    x_train = x_train.astype("float32")
+    x_test = x_test.astype("float32")
+    # Make sure images have shape (28, 28, 1)
+    x_train = np.expand_dims(x_train, -1)
+    x_test = np.expand_dims(x_test, -1)
+    print("x_train shape:", x_train.shape)
+    print(x_train.shape[0], "train samples")
+    print(x_test.shape[0], "test samples")
+
+    # Create TF Datasets
+    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+    return train_data, eval_data
+
+
+"""
+## Single-host, multi-device synchronous training
+
+In this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16).
+Each device will run a copy of your model (called a **replica**). For simplicity, in
+what follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.
+
+**How it works**
+
+At each step of training:
+
+- The current batch of data (called **global batch**) is split into 8 different
+  sub-batches (called **local batches**). For instance, if the global batch has 512
+  samples, each of the 8 local batches will have 64 samples.
+- Each of the 8 replicas independently processes a local batch: they run a forward pass,
+  then a backward pass, outputting the gradient of the weights with respect to the loss of
+  the model on the local batch.
+- The weight updates originating from local gradients are efficiently merged across the 8
+  replicas. Because this is done at the end of every step, the replicas always stay in
+  sync (a standalone sketch of this flow follows right after this list).
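+
+The sketch below is an illustration of the flow described above and is not part of the
+original guide: it uses `jax.pmap` with `jax.lax.pmean` to average gradients across
+replicas, whereas the rest of this guide relies on `jax.sharding` and `jax.jit`. The toy
+linear model, the axis name `"replicas"`, the learning rate of `1e-3`, and the global
+batch size of 512 are arbitrary choices made only for this example.
+
+```python
+import jax
+import jax.numpy as jnp
+
+num_replicas = jax.local_device_count()  # e.g. 8 on an 8-GPU machine
+
+
+def loss_fn(w, x, y):
+    # Toy linear model with an MSE loss, standing in for the real model.
+    return jnp.mean((x @ w - y) ** 2)
+
+
+def replica_step(w, x, y):
+    grads = jax.grad(loss_fn)(w, x, y)
+    # Synchronization step: average the gradients across all replicas so that
+    # every replica applies the exact same weight update.
+    grads = jax.lax.pmean(grads, axis_name="replicas")
+    return w - 1e-3 * grads
+
+
+p_step = jax.pmap(replica_step, axis_name="replicas")
+
+# A "global batch" of 512 samples, split into one local batch per replica.
+x = jnp.zeros((num_replicas, 512 // num_replicas, 16))
+y = jnp.zeros((num_replicas, 512 // num_replicas, 1))
+w = jnp.zeros((num_replicas, 16, 1))  # identical initial weights on every replica
+
+w = p_step(w, x, y)  # one synchronous data-parallel training step
+```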
+
+In practice, the process of synchronously updating the weights of the model replicas is
+handled at the level of each individual weight variable. This is done through
+a `jax.sharding.NamedSharding` that is configured to replicate the variables.
+
+**How to use it**
+
+To do single-host, multi-device synchronous training with a Keras model, you
+would use the `jax.sharding` features. Here's how it works:
+
+- We first create a device mesh using `mesh_utils.create_device_mesh`.
+- We use `jax.sharding.Mesh`, `jax.sharding.NamedSharding` and
+  `jax.sharding.PartitionSpec` to define how to partition JAX arrays.
+  - We specify that we want to replicate the model and optimizer variables
+    across all devices by using a spec with no axis.
+  - We specify that we want to shard the data across devices by using a spec
+    that splits along the batch dimension.
+- We use `jax.device_put` to replicate the model and optimizer variables across
+  devices. This happens once at the beginning.
+- In the training loop, for each batch that we process, we use `jax.device_put`
+  to split the batch across devices before invoking the train step.
+
+Here's the flow, where each step is split into its own utility function:
+"""
+
+# Config
+num_epochs = 2
+batch_size = 64
+
+train_data, eval_data = get_datasets()
+train_data = train_data.batch(batch_size, drop_remainder=True)
+
+model = get_model()
+optimizer = keras.optimizers.Adam(1e-3)
+loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+# Initialize all state with .build()
+(one_batch, one_batch_labels) = next(iter(train_data))
+model.build(one_batch)
+optimizer.build(model.trainable_variables)
+
+
+# This is the loss function that will be differentiated.
+# Keras provides a pure functional forward pass: model.stateless_call
+def compute_loss(trainable_variables, non_trainable_variables, x, y):
+    y_pred, updated_non_trainable_variables = model.stateless_call(
+        trainable_variables, non_trainable_variables, x)
+    loss_value = loss(y, y_pred)
+    return loss_value, updated_non_trainable_variables
+
+
+# Function to compute gradients
+compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)
+
+
+# Training step, Keras provides a pure functional optimizer.stateless_apply
+@jax.jit
+def train_step(train_state, x, y):
+    trainable_variables, non_trainable_variables, optimizer_variables = train_state
+    (loss_value, non_trainable_variables), grads = compute_gradients(
+        trainable_variables, non_trainable_variables, x, y
+    )
+
+    trainable_variables, optimizer_variables = optimizer.stateless_apply(
+        optimizer_variables, grads, trainable_variables
+    )
+
+    return loss_value, (trainable_variables, non_trainable_variables, optimizer_variables)
+
+
+# Replicate the model and optimizer variables on all devices
+def get_replicated_train_state(devices):
+    # All variables will be replicated on all devices
+    var_mesh = Mesh(devices, axis_names=('_'))
+    # In NamedSharding, axes not mentioned are replicated (all axes here)
+    var_replication = NamedSharding(var_mesh, P())
+
+    # Apply the distribution settings to the model variables
+    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
+    non_trainable_variables = jax.device_put(model.non_trainable_variables, var_replication)
+    optimizer_variables = jax.device_put(optimizer.variables, var_replication)
+
+    # Combine all state in a tuple
+    return (trainable_variables, non_trainable_variables, optimizer_variables)
+
+
+num_devices = len(jax.local_devices())
+print(f"Running on {num_devices} devices: {jax.local_devices()}")
+devices = mesh_utils.create_device_mesh((num_devices,))
+
+# Data will be split along the batch axis
+data_mesh = Mesh(devices, axis_names=('batch',))  # naming axes of the mesh
+data_sharding = NamedSharding(data_mesh, P('batch',))  # naming axes of the sharded partition
+
+# Display data sharding
+x, y = next(iter(train_data))
+sharded_x = jax.device_put(x.numpy(), data_sharding)
+print("Data sharding")
+jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28*28]))
+
+train_state = get_replicated_train_state(devices)
+
+# Custom training loop
+for epoch in range(num_epochs):
+    data_iter = iter(train_data)
+    for data in data_iter:
+        x, y = data
+        sharded_x = jax.device_put(x.numpy(), data_sharding)
+        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
+    print("Epoch", epoch, "loss:", loss_value)
+
+# Post-processing model state update to write them back into the model
+trainable_variables, non_trainable_variables, optimizer_variables = train_state
+for variable, value in zip(model.trainable_variables, trainable_variables):
+    variable.assign(value)
+for variable, value in zip(
+    model.non_trainable_variables, non_trainable_variables
+):
+    variable.assign(value)
+
+"""
+That's it!
+"""

From f66a337842098d34aeec05ec2be8672c3f36e4e0 Mon Sep 17 00:00:00 2001
From: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Date: Wed, 12 Jul 2023 14:19:11 -0700
Subject: [PATCH 3/3] optimize torch performance (#465)

* optimize torch performance

* fixing tests

---------

Co-authored-by: Haifeng Jin
---
 keras_core/backend/torch/core.py | 47 ++++++++++++++++++++++-------------------------
 1 file changed, 22 insertions(+), 25 deletions(-)

diff --git a/keras_core/backend/torch/core.py b/keras_core/backend/torch/core.py
index ee84c7870..0aeaea145 100644
--- a/keras_core/backend/torch/core.py
+++ b/keras_core/backend/torch/core.py
@@ -11,7 +11,7 @@
 from keras_core.backend.common.stateless_scope import StatelessScope

 DYNAMIC_SHAPES_OK = True
-
+DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 TORCH_DTYPES = {
     "float16": torch.float16,
@@ -39,20 +39,14 @@ def device_scope(device):
     global_state.set_global_attribute("torch_device", previous_device)


-def get_default_device():
-    return "cuda" if torch.cuda.is_available() else "cpu"
-
-
 def get_device():
     device = global_state.get_global_attribute("torch_device", None)
     if device is None:
-        return get_default_device()
+        return DEFAULT_DEVICE
     return device


 def to_torch_dtype(dtype):
-    if dtype in [value for key, value in TORCH_DTYPES.items()]:
-        return dtype
     dtype = standardize_dtype(dtype)
     dtype = TORCH_DTYPES.get(dtype, None)
     if dtype is None:
@@ -114,32 +108,32 @@ def __eq__(self, other):


 def convert_to_tensor(x, dtype=None):
-    dtype = to_torch_dtype(dtype or getattr(x, "dtype", None))
-    device = get_device()
-    if isinstance(x, int):
-        dtype = torch.int32
-    if isinstance(x, float):
-        dtype = torch.float32
+    if is_tensor(x):
+        if dtype is None:
+            return x
+        return x.to(to_torch_dtype(dtype))
     if isinstance(x, Variable):
         # TorchDynamo has bugs supporting nn.Parameter type check.
         # Return it directly instead of pass it to the rest of the logic in the
         # function.
         return x.value
-    if is_tensor(x):
-        if dtype and dtype != x.dtype:
-            x = x.to(dtype)
-        return x.to(device)
-
+    if isinstance(x, int):
+        return torch.as_tensor(x, dtype=torch.int32, device=get_device())
+    if isinstance(x, float):
+        return torch.as_tensor(x, dtype=torch.float32, device=get_device())
     # Convert to np in case of any array-like that is not list or tuple.
     if not isinstance(x, (list, tuple)):
         x = np.array(x)
     elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x):
         # Handle list or tuple of torch tensors
         return torch.stack([convert_to_tensor(x1) for x1 in x])
-    if isinstance(x, np.ndarray) and x.dtype == np.uint32:
-        # Torch backend does not support uint32.
-        x = x.astype(np.int64)
-    return torch.as_tensor(x, dtype=dtype, device=device)
+    if isinstance(x, np.ndarray):
+        if x.dtype == np.uint32:
+            # Torch backend does not support uint32.
+            x = x.astype(np.int64)
+        dtype = dtype or x.dtype
+        dtype = to_torch_dtype(dtype)
+        return torch.as_tensor(x, dtype=dtype, device=get_device())


 def convert_to_numpy(x):
@@ -170,7 +164,10 @@ def cast(x, dtype):
     if isinstance(x, KerasVariable):
         x = x.value
     if is_tensor(x):
-        return x.to(dtype)
+        if x.dtype == dtype:
+            return x
+        else:
+            return x.to(dtype)
     return convert_to_tensor(x, dtype)


@@ -220,7 +217,7 @@ def symbolic_call(fn, args, kwargs, fill_value):
         )
         return fn(*meta_args, **meta_kwargs)
     except:
-        with device_scope(get_default_device()):
+        with device_scope(DEFAULT_DEVICE):
             # If the `"meta"` device placement fails, fall back to tracing
             # eagerly with tensors on the default device. This will be
             # more robust, but more expensive.
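
For context on the last hunk: `symbolic_call` first tries to trace `fn` with tensors placed on
torch's "meta" device, and only falls back to eager tensors on the default device when that
fails. The standalone sketch below is not part of the patch (the function `fn` and the tensor
shapes are arbitrary); it only illustrates why the meta path is cheap: meta tensors carry shape
and dtype metadata but no storage, so running ops on them yields output shapes without
allocating real memory.

import torch


def fn(x, w):
    # Stand-in for the traced function; any op with a meta kernel works here.
    return torch.relu(x @ w)


# Tensors on the "meta" device have no storage, only shape/dtype metadata.
x = torch.empty((32, 128), device="meta")
w = torch.empty((128, 64), device="meta")

out = fn(x, w)
print(out.shape, out.device)  # torch.Size([32, 64]) meta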