Merge branch 'main' of github.com:keras-team/keras-core
fchollet committed Jul 12, 2023
2 parents 409375d + f66a337 commit d85be71
Showing 3 changed files with 277 additions and 33 deletions.
10 changes: 2 additions & 8 deletions examples/demo_jax_distributed.py
@@ -157,13 +157,7 @@ def make_model():
 # data will be split along the batch axis
 data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
 # naming axes of the sharded partition
-data_sharding = NamedSharding(
-    data_mesh,
-    P(
-        "batch",
-    ),
-)
-
+data_sharding = NamedSharding(data_mesh,P("batch",),)
 # all variables will be replicated on all devices
 var_mesh = Mesh(devices, axis_names=("_"))
 # in NamedSharding, axes that are not mentioned are replicated (all axes here)
@@ -275,7 +269,7 @@ def train_step(train_state, x, y):
     )
 
     trainable_variables, optimizer_variables = optimizer.stateless_apply(
-        grads, train_state.trainable_variables, train_state.optimizer_variables
+        train_state.optimizer_variables, grads, train_state.trainable_variables
     )
 
     return loss_value, TrainingState(
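Note: Keras Core's `optimizer.stateless_apply` expects the optimizer variables first, then the gradients, then the trainable variables; the hunk above reorders the call to match. A minimal sketch of the calling convention (illustrative only, assuming the keras-core JAX backend):

import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras_core as keras

# Tiny model and optimizer, just to demonstrate the argument order.
inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
optimizer = keras.optimizers.Adam(1e-3)
optimizer.build(model.trainable_variables)

# Zero gradients make the update a no-op; the signature is the point.
grads = [jax.numpy.zeros_like(v.value) for v in model.trainable_variables]
trainable_variables, optimizer_variables = optimizer.stateless_apply(
    [v.value for v in optimizer.variables],  # 1) optimizer variables
    grads,  # 2) gradients
    [v.value for v in model.trainable_variables],  # 3) trainable variables
)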
253 changes: 253 additions & 0 deletions guides/distributed_training_with_jax.py
@@ -0,0 +1,253 @@
"""
Title: Multi-GPU distributed training with JAX
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/07/11
Last modified: 2023/07/11
Description: Guide to multi-GPU/TPU training for Keras models with JAX.
Accelerator: GPU or TPU
"""
"""
## Introduction
There are generally two ways to distribute computation across multiple devices:
**Data parallelism**, where a single model gets replicated on multiple devices or
multiple machines. Each of them processes different batches of data, then they merge
their results. There exist many variants of this setup, that differ in how the different
model replicas merge results, in whether they stay in sync at every batch or whether they
are more loosely coupled, etc.
**Model parallelism**, where different parts of a single model run on different devices,
processing a single batch of data together. This works best with models that have a
naturally-parallel architecture, such as models that feature multiple branches.
This guide focuses on data parallelism, in particular **synchronous data parallelism**,
where the different replicas of the model stay in sync after each batch they process.
Synchronicity keeps the model convergence behavior identical to what you would see for
single-device training.
Specifically, this guide teaches you how to use `jax.sharding` APIs to train Keras
models, with minimal changes to your code, on multiple GPUs or TPUS (typically 2 to 16)
installed on a single machine (single host, multi-device training). This is the
most common setup for researchers and small-scale industry workflows.
"""

"""
## Setup
Let's start by defining the function that creates the model that we will train,
and the function that creates the dataset we will train on (MNIST in this case).
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import tensorflow as tf
import keras_core as keras

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


def get_model():
    # Make a simple convnet with batch normalization and dropout.
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    x = keras.layers.Conv2D(
        filters=12, kernel_size=3, padding="same", use_bias=False
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=24,
        kernel_size=6,
        use_bias=False,
        strides=2,
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=32,
        kernel_size=6,
        padding="same",
        strides=2,
        name="large_k",
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model


def get_datasets():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Convert images to float32 (the model's Rescaling layer maps them to [0, 1])
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create TF Datasets
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    return train_data, eval_data


"""
## Single-host, multi-device synchronous training
In this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16).
Each device will run a copy of your model (called a **replica**). For simplicity, in
what follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.
**How it works**
At each step of training:
- The current batch of data (called **global batch**) is split into 8 different
sub-batches (called **local batches**). For instance, if the global batch has 512
samples, each of the 8 local batches will have 64 samples.
- Each of the 8 replicas independently processes a local batch: they run a forward pass,
then a backward pass, outputting the gradient of the weights with respect to the loss of
the model on the local batch.
- The weight updates originating from local gradients are efficiently merged across the 8
replicas. Because this is done at the end of every step, the replicas always stay in
sync.
In practice, the process of synchronously updating the weights of the model replicas is
handled at the level of each individual weight variable. This is done through a using
a `jax.sharding.NamedSharding` that is configured to replicate the variables.
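
For instance, a variable-replication sharding is just a `NamedSharding` whose
`PartitionSpec` names no axes (a sketch, assuming an 8-device mesh; the runnable
version appears in the code below):

```python
devices = mesh_utils.create_device_mesh((8,))
var_mesh = Mesh(devices, axis_names=("_",))
var_replication = NamedSharding(var_mesh, P())  # no axes named -> fully replicated
```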

**How to use it**

To do single-host, multi-device synchronous training with a Keras model, you
would use the `jax.sharding` features. Here's how it works:

- We first create a device mesh using `mesh_utils.create_device_mesh`.
- We use `jax.sharding.Mesh`, `jax.sharding.NamedSharding` and
`jax.sharding.PartitionSpec` to define how to partition JAX arrays.
- We specify that we want to replicate the model and optimizer variables
across all devices by using a spec with no axis.
- We specify that we want to shard the data across devices by using a spec
that splits along the batch dimension.
- We use `jax.device_put` to replicate the model and optimizer variables across
devices. This happens once at the beginning.
- In the training loop, for each batch that we process, we use `jax.device_put`
to split the batch across devices before invoking the train step.

Here's the flow, where each step is split into its own utility function:
"""

# Config
num_epochs = 2
batch_size = 64

train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)

model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)
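# (The stateless calls below require all variables to already exist: building the
# model and optimizer on one batch creates their variables up front.)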


# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables


# Function to compute gradients
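# `has_aux=True` tells JAX that `compute_loss` returns (loss, aux): gradients are
# computed with respect to the scalar loss only, and the aux value (the updated
# non-trainable variables) is passed through unchanged.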
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)


# Training step: Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )

    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


# Replicate the model and optimizer variables on all devices
def get_replicated_train_state(devices):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_"))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())

    # Apply the distribution settings to the model variables
    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)

    # Combine all state in a tuple
    return (trainable_variables, non_trainable_variables, optimizer_variables)


num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))

# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
# naming axes of the sharded partition
data_sharding = NamedSharding(data_mesh, P("batch"))

# Display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))

train_state = get_replicated_train_state(devices)

# Custom training loop
for epoch in range(num_epochs):
    data_iter = iter(train_data)
    for data in data_iter:
        x, y = data
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value)

# Post-processing: write the updated state back into the model variables
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
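
"""
With the updated state written back, the regular stateful Keras workflow works again.
As a quick sanity check (a sketch; evaluation is not covered above, and `eval_data`
comes from `get_datasets()`):
"""

model.compile(loss=loss, metrics=["sparse_categorical_accuracy"])
model.evaluate(eval_data.batch(batch_size))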

"""
That's it!
"""
47 changes: 22 additions & 25 deletions keras_core/backend/torch/core.py
@@ -11,7 +11,7 @@
 from keras_core.backend.common.stateless_scope import StatelessScope
 
 DYNAMIC_SHAPES_OK = True
 
+DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 TORCH_DTYPES = {
     "float16": torch.float16,
@@ -39,20 +39,14 @@ def device_scope(device):
     global_state.set_global_attribute("torch_device", previous_device)
 
 
-def get_default_device():
-    return "cuda" if torch.cuda.is_available() else "cpu"
-
-
 def get_device():
     device = global_state.get_global_attribute("torch_device", None)
     if device is None:
-        return get_default_device()
+        return DEFAULT_DEVICE
     return device
 
 
 def to_torch_dtype(dtype):
-    if dtype in [value for key, value in TORCH_DTYPES.items()]:
-        return dtype
     dtype = standardize_dtype(dtype)
     dtype = TORCH_DTYPES.get(dtype, None)
     if dtype is None:
@@ -114,32 +108,32 @@ def __eq__(self, other):


 def convert_to_tensor(x, dtype=None):
-    if is_tensor(x):
-        if dtype is None:
-            return x
-        return x.to(to_torch_dtype(dtype))
+    dtype = to_torch_dtype(dtype or getattr(x, "dtype", None))
+    device = get_device()
+    if isinstance(x, int):
+        dtype = torch.int32
+    if isinstance(x, float):
+        dtype = torch.float32
     if isinstance(x, Variable):
         # TorchDynamo has bugs supporting nn.Parameter type check.
         # Return it directly instead of pass it to the rest of the logic in the
         # function.
         return x.value
-    if isinstance(x, int):
-        return torch.as_tensor(x, dtype=torch.int32, device=get_device())
-    if isinstance(x, float):
-        return torch.as_tensor(x, dtype=torch.float32, device=get_device())
+    if is_tensor(x):
+        if dtype and dtype != x.dtype:
+            x = x.to(dtype)
+        return x.to(device)
+
     # Convert to np in case of any array-like that is not list or tuple.
     if not isinstance(x, (list, tuple)):
         x = np.array(x)
     elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x):
         # Handle list or tuple of torch tensors
         return torch.stack([convert_to_tensor(x1) for x1 in x])
-    if isinstance(x, np.ndarray):
-        if x.dtype == np.uint32:
-            # Torch backend does not support uint32.
-            x = x.astype(np.int64)
-        dtype = dtype or x.dtype
-        dtype = to_torch_dtype(dtype)
-        return torch.as_tensor(x, dtype=dtype, device=get_device())
+    if isinstance(x, np.ndarray) and x.dtype == np.uint32:
+        # Torch backend does not support uint32.
+        x = x.astype(np.int64)
+    return torch.as_tensor(x, dtype=dtype, device=device)
 
 
 def convert_to_numpy(x):
@@ -170,7 +164,10 @@ def cast(x, dtype):
     if isinstance(x, KerasVariable):
         x = x.value
     if is_tensor(x):
-        return x.to(dtype)
+        if x.dtype == dtype:
+            return x
+        else:
+            return x.to(dtype)
     return convert_to_tensor(x, dtype)


@@ -220,7 +217,7 @@ def symbolic_call(fn, args, kwargs, fill_value):
         )
         return fn(*meta_args, **meta_kwargs)
     except:
-        with device_scope(get_default_device()):
+        with device_scope(DEFAULT_DEVICE):
             # If the `"meta"` device placement fails, fall back to tracing
             # eagerly with tensors on the default device. This will be
             # more robust, but more expensive.
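Note: the `cast` change above short-circuits when the tensor already has the requested
dtype. A minimal sketch of the pattern (illustrative only; `cast_like_diff` is a
hypothetical stand-in for the library function):

import torch

def cast_like_diff(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Same dtype: return the tensor unchanged rather than going through .to().
    if x.dtype == dtype:
        return x
    return x.to(dtype)

t = torch.ones(3, dtype=torch.float32)
assert cast_like_diff(t, torch.float32) is t  # no conversion performed
assert cast_like_diff(t, torch.int32).dtype == torch.int32  # converted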
