Merge pull request #251 from neuronets/ohinds-multi_gpu_warm_start
Training from warm start with multiple GPUs
satra authored Aug 18, 2023
2 parents 9c1a1e0 + 024139f commit dc5461c
Showing 4 changed files with 43 additions and 36 deletions.
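
In short, this change moves the multi_gpu flag from fit() to the estimator constructors and to BaseEstimator.load(), so a warm-started run continues under the same tf.distribute strategy that built the model. A minimal usage sketch of the API after this change; the "unet" base model, the dataset objects, and the save directory are placeholders, not part of the diff:

from nobrainer.processing.segmentation import Segmentation

# dataset_train: a prepared nobrainer dataset (placeholder for this sketch)
seg = Segmentation("unet", multi_gpu=True)    # strategy is chosen in __init__
seg.fit(dataset_train, epochs=1)              # trains under MirroredStrategy
seg.save("trained_model")

# Later: reload under the same kind of strategy and keep training.
seg = Segmentation.load("trained_model", multi_gpu=True)
seg.fit(dataset_train, epochs=2, warm_start=True)
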
11 changes: 11 additions & 0 deletions .github/workflows/guide-notebooks-ec2.yml
@@ -57,6 +57,17 @@ jobs:
set -xe
cd ${{ github.workspace }}
git clone https://github.com/neuronets/nobrainer-book.git
cd nobrainer-book
# if there is a matching book branch, switch to it
if [ $(git ls-remote --heads https://github.com/neuronets/nobrainer-book.git ${{ github.ref_name }} | wc -l) -ne 0 ]; then
echo "Checking out branch ${{ github.ref_name }}"
git checkout ${{ github.ref_name }};
else
echo "No matching branch found, sticking with the default"
fi
cd ${{ github.workspace }}
source env/bin/activate
for notebook_script in $(ls nobrainer-book/docs/nobrainer-guides/scripts/*.py); do
echo "running ${notebook_script}"
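
The new workflow step checks whether the nobrainer-book repository has a branch matching the current one and checks it out, falling back to the default branch otherwise. For running the same check outside of CI, a rough Python equivalent (the workflow itself uses git ls-remote in bash; the URL and branch name below are only examples):

import subprocess

def remote_branch_exists(repo_url, branch):
    """Return True if `branch` exists on the remote at `repo_url`."""
    result = subprocess.run(
        ["git", "ls-remote", "--heads", repo_url, branch],
        capture_output=True, text=True, check=True,
    )
    return bool(result.stdout.strip())

book_repo = "https://github.com/neuronets/nobrainer-book.git"
if remote_branch_exists(book_repo, "master"):
    print("Checking out the matching branch")
else:
    print("No matching branch found, sticking with the default")
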
25 changes: 18 additions & 7 deletions nobrainer/processing/base.py
@@ -8,12 +8,21 @@
import tensorflow as tf


def get_strategy(multi_gpu):
if multi_gpu:
return tf.distribute.MirroredStrategy()
return tf.distribute.get_strategy()


class BaseEstimator:
"""Base class for all high-level models in Nobrainer."""

state_variables = []
model_ = None

def __init__(self, multi_gpu=False):
self.strategy = get_strategy(multi_gpu)

@property
def model(self):
return self.model_
@@ -25,6 +34,12 @@ def save(self, save_dir):
self.model_.save(save_dir)
model_info = {"classname": self.__class__.__name__, "__init__": {}}
for key in inspect.signature(self.__init__).parameters:
# TODO this assumes that all parameters passed to __init__
# are stored as members, which doesn't leave room for
# parameters that are specific to the runtime context.
# (e.g. multi_gpu).
if key == "multi_gpu":
continue
model_info["__init__"][key] = getattr(self, key)
for val in self.state_variables:
model_info[val] = getattr(self, val)
@@ -47,13 +62,9 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
del model_info["__init__"]
for key, value in model_info.items():
setattr(klass, key, value)
if multi_gpu:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
klass.model_ = tf.keras.models.load_model(
model_dir, custom_objects=custom_objects, compile=compile
)
else:
klass.strategy = get_strategy(multi_gpu)

with klass.strategy.scope():
klass.model_ = tf.keras.models.load_model(
model_dir, custom_objects=custom_objects, compile=compile
)
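
get_strategy() centralizes strategy selection: a MirroredStrategy when multi_gpu is requested, otherwise TensorFlow's default (single-replica) strategy. Both model construction and load() now enter self.strategy.scope() before touching the Keras model. A standalone sketch of that pattern; the tiny Dense model is illustrative only, not nobrainer code:

import tensorflow as tf

def get_strategy(multi_gpu):
    if multi_gpu:
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()

strategy = get_strategy(multi_gpu=False)
with strategy.scope():
    # Build and compile (or load) the model inside the strategy's scope so
    # its variables are created on the right devices.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

print("replicas in sync:", strategy.num_replicas_in_sync)
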
17 changes: 6 additions & 11 deletions nobrainer/processing/generation.py
@@ -21,7 +21,9 @@ def __init__(
dimensionality=3,
g_fmap_base=1024,
d_fmap_base=1024,
multi_gpu=False,
):
super().__init__(multi_gpu=multi_gpu)
self.model_ = None
self.latent_size = latent_size
self.label_size = label_size
@@ -45,7 +47,6 @@ def fit(
g_loss=losses.Wasserstein,
d_loss=losses.Wasserstein,
warm_start=False,
multi_gpu=False,
num_parallel_calls=None,
save_freq=500,
):
@@ -82,12 +83,6 @@
d_opt_args_tmp.update(**d_opt_args)
d_opt_args = d_opt_args_tmp

# for multi gpu training
if multi_gpu:
strategy = tf.distribute.MirroredStrategy()
else:
strategy = tf.distribute.get_strategy()

if warm_start:
if self.model_ is None:
raise ValueError("warm_start requested, but model is undefined")
@@ -96,7 +91,7 @@ def fit(
from ..training import ProgressiveGANTrainer

# Instantiate the generator and discriminator
with strategy.scope():
with self.strategy.scope():
generator, discriminator = progressivegan(
latent_size=self.latent_size,
g_fmap_base=self.g_fmap_base,
@@ -112,7 +107,7 @@
self.current_resolution_ = 0

# wrap the losses to work on multiple GPUs
with strategy.scope():
with self.strategy.scope():
d_loss_object = d_loss(reduction=tf.keras.losses.Reduction.NONE)

def compute_d_loss(labels, predictions):
@@ -149,7 +144,7 @@ def _compile():
continue
# create a train dataset with features for resolution
batch_size = info.get("batch_size")
if batch_size % strategy.num_replicas_in_sync:
if batch_size % self.strategy.num_replicas_in_sync:
raise ValueError("batch size must be a multiple of the number of GPUs")

dataset = get_dataset(
@@ -162,7 +157,7 @@ def _compile():
normalizer=info.get("normalizer") or normalizer,
)

with strategy.scope():
with self.strategy.scope():
# grow the networks by one (2^x) resolution
if resolution > self.current_resolution_:
self.model_.generator.add_resolution()
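
With the strategy owned by the estimator, the per-fit multi_gpu flag disappears, but the losses are still instantiated with Reduction.NONE inside the strategy scope. That matches the standard TensorFlow recipe for custom losses under MirroredStrategy: take per-example losses and average them over the global batch so gradients sum correctly across replicas. A generic sketch of that recipe (the body of compute_d_loss is not shown in this hunk, and the MeanSquaredError loss here is only illustrative):

import tensorflow as tf

strategy = tf.distribute.get_strategy()
global_batch_size = 8

with strategy.scope():
    loss_obj = tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, predictions):
    # Per-example losses, averaged over the *global* batch size rather than
    # the per-replica batch size.
    per_example = loss_obj(labels, predictions)
    return tf.nn.compute_average_loss(
        per_example, global_batch_size=global_batch_size)
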
26 changes: 8 additions & 18 deletions nobrainer/processing/segmentation.py
@@ -13,7 +13,9 @@ class Segmentation(BaseEstimator):

state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]

def __init__(self, base_model, model_args=None):
def __init__(self, base_model, model_args=None, multi_gpu=False):
super().__init__(multi_gpu=multi_gpu)

if not isinstance(base_model, str):
self.base_model = base_model.__name__
else:
@@ -30,7 +32,6 @@ def fit(
dataset_validate=None,
epochs=1,
checkpoint_dir=os.getcwd(),
multi_gpu=False,
warm_start=False,
# TODO: figure out whether optimizer args should be flattened
optimizer=None,
@@ -75,26 +76,15 @@ def _compile():
if warm_start:
if self.model is None:
raise ValueError("warm_start requested, but model is undefined")
if multi_gpu:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
_compile()
else:
with self.strategy.scope():
_compile()
else:
mod = importlib.import_module("..models", "nobrainer.processing")
base_model = getattr(mod, self.base_model)
if multi_gpu:
strategy = tf.distribute.MirroredStrategy()
if batch_size % strategy.num_replicas_in_sync:
raise ValueError(
"batch size must be a multiple of the number of GPUs"
)

with strategy.scope():
_create(base_model)
_compile()
else:
if batch_size % self.strategy.num_replicas_in_sync:
raise ValueError("batch size must be a multiple of the number of GPUs")

with self.strategy.scope():
_create(base_model)
_compile()
print(self.model_.summary())
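
The warm-start and cold-start branches now share a single code path because the fallback strategy returned by tf.distribute.get_strategy() reports one replica, so the batch-size check is a no-op on CPU or a single GPU. A quick check of that assumption:

import tensorflow as tf

default_strategy = tf.distribute.get_strategy()
print(default_strategy.num_replicas_in_sync)   # 1 outside any strategy scope
batch_size = 4
# Never raises when there is only one replica.
assert batch_size % default_strategy.num_replicas_in_sync == 0
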
