From f7e9e55f98633de86b0be52000ed5ba108f18e42 Mon Sep 17 00:00:00 2001 From: Vanush Vaswani Date: Wed, 12 Oct 2016 16:24:06 +1100 Subject: [PATCH 1/8] Add domain adaptation example and gradient reversal layer --- examples/README.md | 3 + examples/dann.py | 297 +++++++++++++++++++++++++++ keras/backend/tensorflow_backend.py | 22 +- keras/backend/theano_backend.py | 34 ++- keras/datasets/mnist_m.py | 20 ++ keras/engine/training.py | 8 +- keras/layers/core.py | 29 +++ tests/keras/backend/test_backends.py | 23 +++ 8 files changed, 431 insertions(+), 5 deletions(-) create mode 100644 examples/dann.py create mode 100644 keras/datasets/mnist_m.py diff --git a/examples/README.md b/examples/README.md index 92be33a42eba..279f97f43e3c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -92,3 +92,6 @@ Demonstrates how to build a variational autoencoder. [variational_autoencoder_deconv.py](variational_autoencoder_deconv.py) Demonstrates how to build a variational autoencoder with Keras using deconvolution layers. + +[dann.py](dann.py) +Unsupervised Domain adaptation between source and target datasets to improve classification accuracy. diff --git a/examples/dann.py b/examples/dann.py new file mode 100644 index 000000000000..1e5b6faa8b88 --- /dev/null +++ b/examples/dann.py @@ -0,0 +1,297 @@ +''' +This is the Keras implementation of +'Domain-Adversarial Training of Neural Networks' by Y. Ganin + +This allows domain adaptation (when you want to train on a dataset +with different statistics than a target dataset) in an unsupervised manner +by using the adversarial paradigm to punish features that help discriminate +between the datasets during backpropagation. + +This is achieved by usage of the 'gradient reversal' layer to form +a domain invariant embedding for classification by an MLP. + +The example here uses the 'MNIST-M' dataset as described in the paper. + +Credits: +- Clayton Mellina (https://github.com/pumpikano/tf-dann) for providing + a sketch of implementation (in TF) and utility functions. +- Yusuke Iwasawa (https://github.com/fchollet/keras/issues/3119#issuecomment-230289301) + for Theano implementation (op) for gradient reversal. + +Author: Vanush Vaswani (vanush@gmail.com) +''' + +from __future__ import print_function +from keras.layers import Input, Dense, Dropout, Flatten, Lambda +from keras.layers import Convolution2D, MaxPooling2D +from keras.optimizers import SGD +from keras.models import Model +from keras.utils.visualize_util import plot +from keras.utils import np_utils +from keras.datasets import mnist +import keras.backend as K + +import numpy as np +from matplotlib import pyplot as plt +from mpl_toolkits.axes_grid1 import ImageGrid + +from sklearn.manifold import TSNE + +from keras.layers import GradientReversal +from keras.engine.training import make_batches +from keras.datasets import mnist_m + + +# Helper functions + +def imshow_grid(images, shape=[2, 8]): + """Plot images in a grid of a given shape.""" + fig = plt.figure() + grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05) + + size = shape[0] * shape[1] + for i in range(size): + grid[i].axis('off') + # The AxesGrid object work as a list of axes. 
+ grid[i].imshow(np.swapaxes(np.swapaxes(images[i], 0, 2), 0, 1)) + + +def plot_embedding(X, y, d, title=None): + """Plot an embedding X with the class label y colored by the domain d.""" + x_min, x_max = np.min(X, 0), np.max(X, 0) + X = (X - x_min) / (x_max - x_min) + + # Plot colors numbers + plt.figure(figsize=(10, 10)) + plt.subplot(111) + for i in range(X.shape[0]): + # plot colored number + plt.text(X[i, 0], X[i, 1], str(y[i]), + color=plt.cm.bwr(d[i] / 1.), + fontdict={'weight': 'bold', 'size': 9}) + plt.xticks([]), plt.yticks([]) + if title is not None: + plt.title(title) + + +def batch_gen(batches, id_array, data, labels): + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = id_array[batch_start:batch_end] + if labels is not None: + yield data[batch_ids], labels[batch_ids] + else: + yield data[batch_ids] + np.random.shuffle(id_array) + +# Model parameters + +batch_size = 128 +nb_epoch = 15 +nb_classes = 10 +img_rows, img_cols = 28, 28 +nb_filters = 32 +nb_pool = 2 +nb_conv = 5 + +_TRAIN = K.variable(1, dtype='uint8') + +# Prep source data +(X_train, y_train), (X_test, y_test) = mnist.load_data() +y_train = np_utils.to_categorical(y_train, nb_classes) +y_test = np_utils.to_categorical(y_test, nb_classes) + +# Prep target data +mnistm = mnist_m.load_data() +XT_test = np.swapaxes(np.swapaxes(mnistm['test'], 1, 3), 2, 3) +XT_train = np.swapaxes(np.swapaxes(mnistm['train'], 1, 3), 2, 3) + +X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols) +X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols) +X_train = np.concatenate([X_train, X_train, X_train], axis=1) +X_test = np.concatenate([X_test, X_test, X_test], axis=1) + +X_train = X_train.astype('float32') +X_test = X_test.astype('float32') +X_train /= 255 +X_test /= 255 + +XT_train = XT_train.astype('float32') +XT_test = XT_test.astype('float32') +XT_train /= 255 +XT_test /= 255 + +domain_labels = np.vstack([np.tile([0, 1], [batch_size / 2, 1]), + np.tile([1., 0.], [batch_size / 2, 1])]) + +# Created mixed dataset for TSNE visualization +num_test = 500 +combined_test_imgs = np.vstack([X_test[:num_test], XT_test[:num_test]]) +combined_test_labels = np.vstack([y_test[:num_test], y_test[:num_test]]) +combined_test_domain = np.vstack([np.tile([1., 0.], [num_test, 1]), + np.tile([0., 1.], [num_test, 1])]) + + +class DANNBuilder(object): + def __init__(self): + self.model = None + self.net = None + self.domain_invariant_features = None + self.grl = None + self.opt = SGD() + + def _build_feature_extractor(self, model_input): + '''Build segment of net for feature extraction.''' + net = Convolution2D(nb_filters, nb_conv, nb_conv, + border_mode='valid', + activation='relu')(model_input) + net = Convolution2D(nb_filters, nb_conv, nb_conv, + activation='relu')(net) + net = MaxPooling2D(pool_size=(nb_pool, nb_pool))(net) + net = Dropout(0.5)(net) + net = Flatten()(net) + self.domain_invariant_features = net + return net + + def _build_classifier(self, model_input): + net = Dense(128, activation='relu')(model_input) + net = Dropout(0.5)(net) + net = Dense(nb_classes, activation='softmax', + name='classifier_output')(net) + return net + + def build_source_model(self, main_input, plot_model=False): + net = self._build_feature_extractor(main_input) + net = self._build_classifier(net) + model = Model(input=main_input, output=net) + if plot_model: + plot(model, show_shapes=True) + model.compile(loss={'classifier_output': 'categorical_crossentropy'}, + optimizer=self.opt, metrics=['accuracy']) + return 
model + + def build_dann_model(self, main_input, plot_model=False): + net = self._build_feature_extractor(main_input) + self.grl = GradientReversal(1.0) + branch = self.grl(net) + branch = Dense(128, activation='relu')(branch) + branch = Dropout(0.1)(branch) + branch = Dense(2, activation='softmax', name='domain_output')(branch) + + # When building DANN model, route first half of batch (source examples) + # to domain classifier, and route full batch (half source, half target) + # to the domain classifier. + net = Lambda(lambda x: K.switch(K.learning_phase(), + x[:int(batch_size / 2), :], x, lazy=True), + output_shape=lambda x: ((batch_size / 2,) + + x[1:]) if _TRAIN else x[0:])(net) + + net = self._build_classifier(net) + model = Model(input=main_input, output=[branch, net]) + if plot_model: + plot(model, show_shapes=True) + model.compile(loss={'classifier_output': 'categorical_crossentropy', + 'domain_output': 'categorical_crossentropy'}, + optimizer=self.opt, metrics=['accuracy']) + return model + + def build_tsne_model(self, main_input): + '''Create model to output intermediate layer + activations to visualize domain invariant features''' + tsne_model = Model(input=main_input, + output=self.domain_invariant_features) + return tsne_model + + +main_input = Input(shape=(3, img_rows, img_cols), name='main_input') + +builder = DANNBuilder() +src_model = builder.build_source_model(main_input) +src_vis = builder.build_tsne_model(main_input) + +dann_model = builder.build_dann_model(main_input) +dann_vis = builder.build_tsne_model(main_input) +print('Training source only model') +src_model.fit(X_train, y_train, batch_size=64, nb_epoch=10, verbose=1, + validation_data=(X_test, y_test)) +print('Evaluating target samples on source-only model') +print('Accuracy: ', src_model.evaluate(XT_test, y_test)[1]) + +# Broken out training loop for a DANN model. +src_index_arr = np.arange(X_train.shape[0]) +target_index_arr = np.arange(XT_train.shape[0]) + +batches_per_epoch = len(X_train) / batch_size +num_steps = nb_epoch * batches_per_epoch +j = 0 + +print('Training DANN model') + +for i in range(nb_epoch): + + batches = make_batches(X_train.shape[0], batch_size / 2) + target_batches = make_batches(XT_train.shape[0], batch_size / 2) + + src_gen = batch_gen(batches, src_index_arr, X_train, y_train) + target_gen = batch_gen(target_batches, target_index_arr, XT_train, None) + + losses = list() + acc = list() + + print('Epoch ', i) + + for (xb, yb) in src_gen: + + # Update learning rate and gradient multiplier as described in + # the paper. + p = float(j) / num_steps + l = 2. / (1. + np.exp(-10. * p)) - 1 + lr = 0.01 / (1. 
+ 10 * p)**0.75 + builder.grl.l = l + builder.opt.lr = lr + + if xb.shape[0] != batch_size / 2: + continue + + try: + xt = target_gen.next() + except: + # Regeneration + target_gen = target_gen(target_batches, target_index_arr, XT_train, + None) + + # Concatenate source and target batch + xb = np.vstack([xb, xt]) + + metrics = dann_model.train_on_batch({'main_input': xb}, + {'classifier_output': yb, + 'domain_output': domain_labels}, + check_batch_dim=False) + j += 1 + +print('Evaluating target samples on DANN model') +out = dann_model.predict_on_batch(XT_test[0:batch_size / 2]) +out = np.argmax(out[1], axis=1) +actual = np.argmax(y_test[0:batch_size / 2], axis=1) +acc = float(np.sum((out == actual))) / float(len(out)) +print('Accuracy: ', acc) +print('Visualizing output of domain invariant features') + +# Plot both MNIST and MNIST-M +imshow_grid(X_train) +imshow_grid(XT_train) + +src_embedding = src_vis.predict([combined_test_imgs]) +src_tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000) +tsne = src_tsne.fit_transform(src_embedding) + +plot_embedding(tsne, combined_test_labels.argmax(1), + combined_test_domain.argmax(1), 'Source only') + +dann_embedding = dann_vis.predict([combined_test_imgs]) +dann_tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000) +tsne = dann_tsne.fit_transform(dann_embedding) + +plot_embedding(tsne, combined_test_labels.argmax(1), + combined_test_domain.argmax(1), 'DANN') + +plt.show() diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index e68829398518..90906515afbb 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1310,7 +1310,7 @@ def _cond(condition, then_lambda, else_lambda): return cond_fn(condition, then_lambda, else_lambda) -def switch(condition, then_expression, else_expression): +def switch(condition, then_expression, else_expression, lazy=False): '''Switches between two operations depending on a scalar value (int or bool). Note that both `then_expression` and `else_expression` should be symbolic tensors of the *same shape*. @@ -1319,6 +1319,7 @@ def switch(condition, then_expression, else_expression): condition: scalar tensor. then_expression: TensorFlow operation. else_expression: TensorFlow operation. + lazy: Unused (compatibility with Theano backend) ''' x_shape = copy.copy(then_expression.get_shape()) x = _cond(tf.cast(condition, 'bool'), @@ -1924,3 +1925,22 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, for st in decoded] return (decoded_dense, log_prob) + + +class ReverseGradientBuilder(object): + '''Flips the sign of incoming gradients in training''' + def __init__(self): + self.num_calls = 0 + + def __call__(self, x, l=1.0): + grad_name = "GradientReversal%d" % self.num_calls + @tf.python.framework.ops.RegisterGradient(grad_name) + def _flip_gradients(op, grad): + return [tf.neg(grad) * l] + + g = get_session().graph + with g.gradient_override_map({'Identity': grad_name}): + y = tf.identity(x) + + self.num_calls += 1 + return y diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 30f03adbddfe..26558a95c66e 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -963,9 +963,15 @@ def _step(input, *states): return last_output, outputs, states -def switch(condition, then_expression, else_expression): +def switch(condition, then_expression, else_expression, lazy=False): '''condition: scalar tensor. 
+ + # Arguments: + lazy: Use ifelse op which evaluates arguments in a lazy manner. ''' + if lazy: + return theano.ifelse.ifelse(condition, then_expression, else_expression) + return T.switch(condition, then_expression, else_expression) @@ -1686,3 +1692,29 @@ def ctc_step(y_true_step, y_pred_step, input_length_step, label_length_step): ret = ret.dimshuffle('x', 0) return ret + + +class ReverseGradient(theano.Op): + '''Flips the sign of the gradient during training.''' + view_map = {0: [0]} + __props__ = ('l', ) + def __init__(self, l): + super(ReverseGradient, self).__init__() + self.l = l + + def make_node(self, x): + assert (hasattr(self, '_props'), + 'Your version of theano is too old to support __props__.') + x = T.as_tensor_variable(x) + return theano.Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage): + xin, = inputs + xout, = output_storage + xout[0] = xin + + def grad(self, input, output_gradients): + return [-self.l * output_gradients[0]] + + def infer_shape(self, node, i0_shapes): + return i0_shapes diff --git a/keras/datasets/mnist_m.py b/keras/datasets/mnist_m.py new file mode 100644 index 000000000000..4b0d42c01507 --- /dev/null +++ b/keras/datasets/mnist_m.py @@ -0,0 +1,20 @@ +import gzip +from ..utils.data_utils import get_file +from six.moves import cPickle +import sys + +def load_data(path='keras_mnistm.pkl.gz'): + path = get_file(path, origin='https://github.com/VanushVaswani/keras_mnistm/releases/download/1.0/keras_mnistm.pkl.gz') + + if path.endswith('.gz'): + f = gzip.open(path, 'rb') + else: + f = open(path, 'rb') + + if sys.version_info < (3,): + data = cPickle.load(f) + else: + data = cPickle.load(f, encoding='bytes') + + f.close() + return data diff --git a/keras/engine/training.py b/keras/engine/training.py index 4b50455aafec..cf9bfaaffdc7 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -983,7 +983,8 @@ def _standardize_user_data(self, x, y, sample_weights = [standardize_weights(ref, sw, cw, mode) for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, self.sample_weight_modes)] - check_array_lengths(x, y, sample_weights) + if check_batch_dim: + check_array_lengths(x, y, sample_weights) check_loss_and_target_compatibility(y, self.loss_functions, self.internal_output_shapes) if self.stateful and batch_size: if x[0].shape[0] % batch_size != 0: @@ -1192,7 +1193,7 @@ def predict(self, x, batch_size=32, verbose=0): batch_size=batch_size, verbose=verbose) def train_on_batch(self, x, y, - sample_weight=None, class_weight=None): + sample_weight=None, class_weight=None, check_batch_dim=True): '''Runs a single gradient update on a single batch of data. # Arguments @@ -1215,6 +1216,7 @@ def train_on_batch(self, x, y, from this class during training. This can be useful to tell the model to "pay more attention" to samples from an under-represented class. + check_batch_dim: Whether to check batch dimensions for sanity. # Returns Scalar training loss (if the model has a single output and no metrics) @@ -1225,7 +1227,7 @@ def train_on_batch(self, x, y, x, y, sample_weights = self._standardize_user_data(x, y, sample_weight=sample_weight, class_weight=class_weight, - check_batch_dim=True) + check_batch_dim=check_batch_dim) if self.uses_learning_phase and type(K.learning_phase()) is not int: ins = x + y + sample_weights + [1.] 
else: diff --git a/keras/layers/core.py b/keras/layers/core.py index 1311dbc4d2f3..7eb9878fdf68 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -1213,3 +1213,32 @@ def get_config(self): 'input_length': self.input_length} base_config = super(TimeDistributedDense, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + +class GradientReversal(Layer): + def __init__(self, l, **kwargs): + super(GradientReversal, self).__init__(**kwargs) + self.l = K.variable(l) + self.supports_masking = False + + if K._BACKEND == 'theano': + self.op = K.ReverseGradient(self.l) + elif K._BACKEND == 'tensorflow': + self.op = K.ReverseGradientBuilder() + + def build(self, input_shape): + self.trainable_weights = [] + + def call(self, x, mask=None): + if K._BACKEND == 'theano': + return self.op(x) + elif K._BACKEND == 'tensorflow': + return self.op(x, self.l) + + def get_output_shape_for(self, input_shape): + return input_shape + + def get_config(self): + config = {'lambda' : self.l} + base_config = super(GradientReversal, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/tests/keras/backend/test_backends.py b/tests/keras/backend/test_backends.py index cc9bf422f0c2..27fbc0214bed 100644 --- a/tests/keras/backend/test_backends.py +++ b/tests/keras/backend/test_backends.py @@ -249,6 +249,29 @@ def test_gradient(self): assert_allclose(zero_zth, zth, atol=1e-05) assert_allclose(zero_ztf, ztf, atol=1e-05) + def test_gradient_reversal(self): + val = np.random.random((4, 2)) + lth = KTH.variable(0.5) + ltf = KTF.variable(0.5) + + xth = KTH.variable(val) + xtf = KTF.variable(val) + + yth = xth ** 2 + ytf = xtf ** 2 + + lossth = KTH.ReverseGradient(lth)(KTH.sum(yth)) + losstf = KTF.ReverseGradientBuilder()(KTF.sum(ytf), ltf) + + gradth = KTH.gradients(lossth, [xth]) + gradtf = KTF.gradients(losstf, [xtf]) + zth = KTH.eval(gradth[0]) + ztf = KTF.eval(gradtf[0]) + + assert_allclose(zth, -2 * 0.5 * val, atol=1e-05) + assert_allclose(ztf, -2 * 0.5 * val, atol=1e-05) + assert_allclose(zth, ztf, atol=1e-05) + def test_function(self): val = np.random.random((4, 2)) input_val = np.random.random((4, 2)) From bcc06eb511cc78d95b4498b982faa0a24210e841 Mon Sep 17 00:00:00 2001 From: Vanush Vaswani Date: Wed, 12 Oct 2016 16:54:20 +1100 Subject: [PATCH 2/8] Fix pep8 issues --- keras/backend/tensorflow_backend.py | 1 + keras/backend/theano_backend.py | 8 ++++---- keras/datasets/mnist_m.py | 2 +- keras/layers/core.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 90906515afbb..f06735abf031 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1934,6 +1934,7 @@ def __init__(self): def __call__(self, x, l=1.0): grad_name = "GradientReversal%d" % self.num_calls + @tf.python.framework.ops.RegisterGradient(grad_name) def _flip_gradients(op, grad): return [tf.neg(grad) * l] diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 26558a95c66e..76987247a5cb 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -966,8 +966,8 @@ def _step(input, *states): def switch(condition, then_expression, else_expression, lazy=False): '''condition: scalar tensor. - # Arguments: - lazy: Use ifelse op which evaluates arguments in a lazy manner. + # Arguments: + lazy: Use ifelse op which evaluates arguments in a lazy manner. 
''' if lazy: return theano.ifelse.ifelse(condition, then_expression, else_expression) @@ -1698,13 +1698,13 @@ class ReverseGradient(theano.Op): '''Flips the sign of the gradient during training.''' view_map = {0: [0]} __props__ = ('l', ) + def __init__(self, l): super(ReverseGradient, self).__init__() self.l = l def make_node(self, x): - assert (hasattr(self, '_props'), - 'Your version of theano is too old to support __props__.') + assert (hasattr(self, '_props'), 'Your version of theano is too old to support __props__.') x = T.as_tensor_variable(x) return theano.Apply(self, [x], [x.type()]) diff --git a/keras/datasets/mnist_m.py b/keras/datasets/mnist_m.py index 4b0d42c01507..03f936fb3659 100644 --- a/keras/datasets/mnist_m.py +++ b/keras/datasets/mnist_m.py @@ -11,7 +11,7 @@ def load_data(path='keras_mnistm.pkl.gz'): else: f = open(path, 'rb') - if sys.version_info < (3,): + if sys.version_info < (3,): data = cPickle.load(f) else: data = cPickle.load(f, encoding='bytes') diff --git a/keras/layers/core.py b/keras/layers/core.py index 7eb9878fdf68..d8730b18343e 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -1239,6 +1239,6 @@ def get_output_shape_for(self, input_shape): return input_shape def get_config(self): - config = {'lambda' : self.l} + config = {'lambda': self.l} base_config = super(GradientReversal, self).get_config() return dict(list(base_config.items()) + list(config.items())) From d53a4a6459abedb28138a34aa51be5b3f230010b Mon Sep 17 00:00:00 2001 From: Vanush Vaswani Date: Fri, 14 Oct 2016 20:21:24 +1100 Subject: [PATCH 3/8] Evaluate over entire test set --- examples/dann.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/dann.py b/examples/dann.py index 1e5b6faa8b88..f22d99db71e8 100644 --- a/examples/dann.py +++ b/examples/dann.py @@ -83,6 +83,17 @@ def batch_gen(batches, id_array, data, labels): yield data[batch_ids] np.random.shuffle(id_array) + +def evaluate_dann(num_batches, size): + acc = 0 + for i in range(0, num_batches): + _, prob = dann_model.predict_on_batch(XT_test[i * size:i * size + size]) + predictions = np.argmax(prob, axis=1) + actual = np.argmax(y_test[i * size:i * size + size], axis=1) + acc += float(np.sum((predictions == actual))) / size + return acc / num_batches + + # Model parameters batch_size = 128 @@ -269,11 +280,10 @@ def build_tsne_model(self, main_input): j += 1 print('Evaluating target samples on DANN model') -out = dann_model.predict_on_batch(XT_test[0:batch_size / 2]) -out = np.argmax(out[1], axis=1) -actual = np.argmax(y_test[0:batch_size / 2], axis=1) -acc = float(np.sum((out == actual))) / float(len(out)) -print('Accuracy: ', acc) +size = batch_size / 2 +nb_testbatches = XT_test.shape[0] / size +acc = evaluate_dann(nb_testbatches, size) +print('Accuracy:', acc) print('Visualizing output of domain invariant features') # Plot both MNIST and MNIST-M From 262f9d00ec6e09bb2069aff7b00d2504e7c5bb4e Mon Sep 17 00:00:00 2001 From: Vanush Vaswani Date: Sat, 15 Oct 2016 12:29:20 +1100 Subject: [PATCH 4/8] Clean up backend implementation --- keras/backend/tensorflow_backend.py | 11 ++++++----- keras/backend/theano_backend.py | 12 ++++++------ keras/layers/core.py | 23 +++++++++++------------ tests/keras/backend/test_backends.py | 2 +- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index f06735abf031..cafd7588bdbf 100644 --- a/keras/backend/tensorflow_backend.py +++ 
b/keras/backend/tensorflow_backend.py @@ -1927,17 +1927,18 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, return (decoded_dense, log_prob) -class ReverseGradientBuilder(object): - '''Flips the sign of incoming gradients in training''' - def __init__(self): +class ReverseGradient(object): + '''Flips the sign of incoming gradient during training.''' + def __init__(self, hp_lambda): self.num_calls = 0 + self.hp_lambda = hp_lambda - def __call__(self, x, l=1.0): + def __call__(self, x): grad_name = "GradientReversal%d" % self.num_calls @tf.python.framework.ops.RegisterGradient(grad_name) def _flip_gradients(op, grad): - return [tf.neg(grad) * l] + return [tf.neg(grad) * self.hp_lambda] g = get_session().graph with g.gradient_override_map({'Identity': grad_name}): diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 76987247a5cb..0e4f9ae0e130 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -1695,16 +1695,16 @@ def ctc_step(y_true_step, y_pred_step, input_length_step, label_length_step): class ReverseGradient(theano.Op): - '''Flips the sign of the gradient during training.''' + '''Flips the sign of incoming gradient during training.''' view_map = {0: [0]} - __props__ = ('l', ) + __props__ = ('hp_lambda', ) - def __init__(self, l): + def __init__(self, hp_lambda): super(ReverseGradient, self).__init__() - self.l = l + self.hp_lambda = hp_lambda def make_node(self, x): - assert (hasattr(self, '_props'), 'Your version of theano is too old to support __props__.') + assert hasattr(self, '_props'), 'Your version of theano is too old to support __props__.' x = T.as_tensor_variable(x) return theano.Apply(self, [x], [x.type()]) @@ -1714,7 +1714,7 @@ def perform(self, node, inputs, output_storage): xout[0] = xin def grad(self, input, output_gradients): - return [-self.l * output_gradients[0]] + return [-self.hp_lambda * output_gradients[0]] def infer_shape(self, node, i0_shapes): return i0_shapes diff --git a/keras/layers/core.py b/keras/layers/core.py index d8730b18343e..46091fb04439 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -1216,29 +1216,28 @@ def get_config(self): class GradientReversal(Layer): - def __init__(self, l, **kwargs): + ''' + Flip the sign of gradient during training. + + # Arguments: + hp_lambda: Scalar to multiply the flipped gradient. 
+ ''' + def __init__(self, hp_lambda, **kwargs): super(GradientReversal, self).__init__(**kwargs) - self.l = K.variable(l) + self.hp_lambda = K.variable(hp_lambda) self.supports_masking = False - - if K._BACKEND == 'theano': - self.op = K.ReverseGradient(self.l) - elif K._BACKEND == 'tensorflow': - self.op = K.ReverseGradientBuilder() + self.op = K.ReverseGradient(self.hp_lambda) def build(self, input_shape): self.trainable_weights = [] def call(self, x, mask=None): - if K._BACKEND == 'theano': - return self.op(x) - elif K._BACKEND == 'tensorflow': - return self.op(x, self.l) + return self.op(x) def get_output_shape_for(self, input_shape): return input_shape def get_config(self): - config = {'lambda': self.l} + config = {'hp_lambda': self.hp_lambda} base_config = super(GradientReversal, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tests/keras/backend/test_backends.py b/tests/keras/backend/test_backends.py index 27fbc0214bed..30ad1f6f17c5 100644 --- a/tests/keras/backend/test_backends.py +++ b/tests/keras/backend/test_backends.py @@ -261,7 +261,7 @@ def test_gradient_reversal(self): ytf = xtf ** 2 lossth = KTH.ReverseGradient(lth)(KTH.sum(yth)) - losstf = KTF.ReverseGradientBuilder()(KTF.sum(ytf), ltf) + losstf = KTF.ReverseGradient(ltf)(KTF.sum(ytf)) gradth = KTH.gradients(lossth, [xth]) gradtf = KTF.gradients(losstf, [xtf]) From fd8883acf8fe603fdb221c8f7964441a49ac9f03 Mon Sep 17 00:00:00 2001 From: Vanush Vaswani Date: Wed, 26 Oct 2016 23:11:40 +1100 Subject: [PATCH 5/8] Simplify implementation so reverse gradient op is a function --- examples/dann.py | 28 +++++++++++++------------- keras/backend/tensorflow_backend.py | 30 +++++++++++++--------------- keras/backend/theano_backend.py | 10 +++++++--- keras/engine/training.py | 2 +- keras/layers/core.py | 10 ++++------ tests/keras/backend/test_backends.py | 4 ++-- 6 files changed, 42 insertions(+), 42 deletions(-) diff --git a/examples/dann.py b/examples/dann.py index f22d99db71e8..0ca8c2b17d1d 100644 --- a/examples/dann.py +++ b/examples/dann.py @@ -84,10 +84,13 @@ def batch_gen(batches, id_array, data, labels): np.random.shuffle(id_array) -def evaluate_dann(num_batches, size): +def evaluate_dann(X_test, batch_size): + """Predict batch by batch.""" + size = batch_size / 2 + num_batches = X_test.shape[0] / size acc = 0 for i in range(0, num_batches): - _, prob = dann_model.predict_on_batch(XT_test[i * size:i * size + size]) + _, prob = dann_model.predict_on_batch(X_test[i * size:i * size + size]) predictions = np.argmax(prob, axis=1) actual = np.argmax(y_test[i * size:i * size + size], axis=1) acc += float(np.sum((predictions == actual))) / size @@ -180,10 +183,10 @@ def build_source_model(self, main_input, plot_model=False): optimizer=self.opt, metrics=['accuracy']) return model - def build_dann_model(self, main_input, plot_model=False): + def build_dann_model(self, main_input, hp_lambda, plot_model=False): net = self._build_feature_extractor(main_input) - self.grl = GradientReversal(1.0) - branch = self.grl(net) + self.grl = GradientReversal() + branch = self.grl(net, hp_lambda) branch = Dense(128, activation='relu')(branch) branch = Dropout(0.1)(branch) branch = Dense(2, activation='softmax', name='domain_output')(branch) @@ -191,10 +194,8 @@ def build_dann_model(self, main_input, plot_model=False): # When building DANN model, route first half of batch (source examples) # to domain classifier, and route full batch (half source, half target) # to the domain classifier. 
- net = Lambda(lambda x: K.switch(K.learning_phase(), - x[:int(batch_size / 2), :], x, lazy=True), - output_shape=lambda x: ((batch_size / 2,) + - x[1:]) if _TRAIN else x[0:])(net) + net = Lambda(lambda x: K.switch(K.learning_phase(), x[:int(batch_size / 2), :], x, lazy=True), + output_shape=lambda x: ((batch_size / 2,) + x[1:]))(net) net = self._build_classifier(net) model = Model(input=main_input, output=[branch, net]) @@ -219,7 +220,8 @@ def build_tsne_model(self, main_input): src_model = builder.build_source_model(main_input) src_vis = builder.build_tsne_model(main_input) -dann_model = builder.build_dann_model(main_input) +hp_lambda = K.variable(1.0) +dann_model = builder.build_dann_model(main_input, hp_lambda) dann_vis = builder.build_tsne_model(main_input) print('Training source only model') src_model.fit(X_train, y_train, batch_size=64, nb_epoch=10, verbose=1, @@ -257,7 +259,7 @@ def build_tsne_model(self, main_input): p = float(j) / num_steps l = 2. / (1. + np.exp(-10. * p)) - 1 lr = 0.01 / (1. + 10 * p)**0.75 - builder.grl.l = l + hp_lambda = l builder.opt.lr = lr if xb.shape[0] != batch_size / 2: @@ -280,9 +282,7 @@ def build_tsne_model(self, main_input): j += 1 print('Evaluating target samples on DANN model') -size = batch_size / 2 -nb_testbatches = XT_test.shape[0] / size -acc = evaluate_dann(nb_testbatches, size) +acc = evaluate_dann(XT_test, batch_size) print('Accuracy:', acc) print('Visualizing output of domain invariant features') diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index cafd7588bdbf..6ff2fbcd1c62 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1927,22 +1927,20 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, return (decoded_dense, log_prob) -class ReverseGradient(object): - '''Flips the sign of incoming gradient during training.''' - def __init__(self, hp_lambda): - self.num_calls = 0 - self.hp_lambda = hp_lambda - - def __call__(self, x): - grad_name = "GradientReversal%d" % self.num_calls +def reverse_gradient(X, hp_lambda): + '''Flips the sign of the incoming gradient during training.''' + try: + reverse_gradient.num_calls += 1 + except AttributeError: + reverse_gradient.num_calls = 1 - @tf.python.framework.ops.RegisterGradient(grad_name) - def _flip_gradients(op, grad): - return [tf.neg(grad) * self.hp_lambda] + grad_name = "GradientReversal%d" % reverse_gradient.num_calls + @tf.python.framework.ops.RegisterGradient(grad_name) + def _flip_gradients(op, grad): + return [tf.neg(grad) * hp_lambda] - g = get_session().graph - with g.gradient_override_map({'Identity': grad_name}): - y = tf.identity(x) + g = get_session().graph + with g.gradient_override_map({'Identity': grad_name}): + y = tf.identity(X) - self.num_calls += 1 - return y + return y diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 0e4f9ae0e130..a5e4c6925ed9 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -1696,12 +1696,11 @@ def ctc_step(y_true_step, y_pred_step, input_length_step, label_length_step): class ReverseGradient(theano.Op): '''Flips the sign of incoming gradient during training.''' - view_map = {0: [0]} __props__ = ('hp_lambda', ) - def __init__(self, hp_lambda): + def __init__(self): super(ReverseGradient, self).__init__() - self.hp_lambda = hp_lambda + self.hp_lambda = None def make_node(self, x): assert hasattr(self, '_props'), 'Your version of theano is too old to support __props__.' 
@@ -1718,3 +1717,8 @@ def grad(self, input, output_gradients): def infer_shape(self, node, i0_shapes): return i0_shapes + +_reverse_gradient = ReverseGradient() +def reverse_gradient(x, hp_lambda): + _reverse_gradient.hp_lambda = hp_lambda + return _reverse_gradient(x) diff --git a/keras/engine/training.py b/keras/engine/training.py index cf9bfaaffdc7..745ed5ad5f5e 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -1216,7 +1216,7 @@ def train_on_batch(self, x, y, from this class during training. This can be useful to tell the model to "pay more attention" to samples from an under-represented class. - check_batch_dim: Whether to check batch dimensions for sanity. + check_batch_dim: Check batch dimensions for consistency. (default is True) # Returns Scalar training loss (if the model has a single output and no metrics) diff --git a/keras/layers/core.py b/keras/layers/core.py index 46091fb04439..337c186ffa10 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -1222,22 +1222,20 @@ class GradientReversal(Layer): # Arguments: hp_lambda: Scalar to multiply the flipped gradient. ''' - def __init__(self, hp_lambda, **kwargs): + def __init__(self, **kwargs): super(GradientReversal, self).__init__(**kwargs) - self.hp_lambda = K.variable(hp_lambda) self.supports_masking = False - self.op = K.ReverseGradient(self.hp_lambda) def build(self, input_shape): self.trainable_weights = [] - def call(self, x, mask=None): - return self.op(x) + def call(self, x, hp_lambda, mask=None): + return K.reverse_gradient(x, hp_lambda) def get_output_shape_for(self, input_shape): return input_shape def get_config(self): - config = {'hp_lambda': self.hp_lambda} + config = {} base_config = super(GradientReversal, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tests/keras/backend/test_backends.py b/tests/keras/backend/test_backends.py index 30ad1f6f17c5..63ae6658e079 100644 --- a/tests/keras/backend/test_backends.py +++ b/tests/keras/backend/test_backends.py @@ -260,8 +260,8 @@ def test_gradient_reversal(self): yth = xth ** 2 ytf = xtf ** 2 - lossth = KTH.ReverseGradient(lth)(KTH.sum(yth)) - losstf = KTF.ReverseGradient(ltf)(KTF.sum(ytf)) + lossth = KTH.reverse_gradient(KTH.sum(yth), lth) + losstf = KTF.reverse_gradient(KTF.sum(ytf), ltf) gradth = KTH.gradients(lossth, [xth]) gradtf = KTF.gradients(losstf, [xtf]) From 777f849b005160d8ccec6e747b3d0af185b93b89 Mon Sep 17 00:00:00 2001 From: Vanush Vaswani Date: Wed, 26 Oct 2016 23:28:28 +1100 Subject: [PATCH 6/8] pep8 --- keras/backend/tensorflow_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 6ff2fbcd1c62..8aa16510fb73 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1935,6 +1935,7 @@ def reverse_gradient(X, hp_lambda): reverse_gradient.num_calls = 1 grad_name = "GradientReversal%d" % reverse_gradient.num_calls + @tf.python.framework.ops.RegisterGradient(grad_name) def _flip_gradients(op, grad): return [tf.neg(grad) * hp_lambda] From 7bdc36365075b86ec4c5296940b9a506ccdd62f9 Mon Sep 17 00:00:00 2001 From: Vanush Vaswani Date: Wed, 26 Oct 2016 23:37:50 +1100 Subject: [PATCH 7/8] Fix up comment a bit --- keras/engine/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/engine/training.py b/keras/engine/training.py index 745ed5ad5f5e..ff2ff2880787 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -1216,7 +1216,7 
@@ def train_on_batch(self, x, y, from this class during training. This can be useful to tell the model to "pay more attention" to samples from an under-represented class. - check_batch_dim: Check batch dimensions for consistency. (default is True) + check_batch_dim: Check batch dimensions for shape consistency. (default is True) # Returns Scalar training loss (if the model has a single output and no metrics) From 3c920912e785c969b224c419f83ea0f30124f7f5 Mon Sep 17 00:00:00 2001 From: Vanush Vaswani Date: Wed, 26 Oct 2016 23:46:03 +1100 Subject: [PATCH 8/8] Fix up comment --- keras/layers/core.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/keras/layers/core.py b/keras/layers/core.py index 337c186ffa10..0d82ee5ae09d 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -1216,12 +1216,7 @@ def get_config(self): class GradientReversal(Layer): - ''' - Flip the sign of gradient during training. - - # Arguments: - hp_lambda: Scalar to multiply the flipped gradient. - ''' + '''Flip the sign of gradient during training.''' def __init__(self, **kwargs): super(GradientReversal, self).__init__(**kwargs) self.supports_masking = False
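
For reference, a minimal sketch of how the pieces introduced by this series fit together: a shared feature extractor feeding a label classifier directly and a domain classifier through the GradientReversal layer. This is only a sketch, assuming the patches above are applied to a Keras 1.x install; the layer call signature and hp_lambda handling follow the usage in [PATCH 5/8], the Dense feature extractor is a stand-in for the convolutional stack in examples/dann.py, and the batch-routing Lambda and the hp_lambda / learning-rate schedule from the full example are omitted.

    from keras.layers import Input, Dense, Flatten, GradientReversal
    from keras.optimizers import SGD
    from keras.models import Model
    import keras.backend as K

    nb_classes = 10
    # Gradient-reversal strength; the full example anneals this over training.
    hp_lambda = K.variable(1.0)

    main_input = Input(shape=(3, 28, 28), name='main_input')

    # Shared feature extractor (stand-in for the conv/pool stack in examples/dann.py).
    features = Flatten()(main_input)
    features = Dense(128, activation='relu')(features)

    # Label classifier: sees the features directly, so its gradients update the
    # shared extractor in the usual way.
    label_out = Dense(nb_classes, activation='softmax',
                      name='classifier_output')(features)

    # Domain classifier: sits behind the gradient reversal layer, so on the backward
    # pass its gradient is multiplied by -hp_lambda before reaching the shared
    # extractor, pushing the features toward domain invariance.
    branch = GradientReversal()(features, hp_lambda)
    branch = Dense(64, activation='relu')(branch)
    domain_out = Dense(2, activation='softmax', name='domain_output')(branch)

    model = Model(input=main_input, output=[label_out, domain_out])
    model.compile(optimizer=SGD(),
                  loss={'classifier_output': 'categorical_crossentropy',
                        'domain_output': 'categorical_crossentropy'},
                  metrics=['accuracy'])

In the full example each training batch is half source / half target: the domain head is trained against fixed 0/1 domain labels, the label head against source class labels only, and hp_lambda and the SGD learning rate are updated per step using the schedule from the paper (lambda = 2/(1 + exp(-10p)) - 1, lr = 0.01/(1 + 10p)^0.75).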