# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import numpy as np
import tensorflow as tf
from tensorflow.contrib import summary

N_CLASSES = 10

# ## Gradient Reversal Layer
#
# In the forward pass this layer is the identity map; in the backward pass
# it reverses the sign of the gradient, optionally multiplying the reversed
# gradient by a weight.
#
# For details, see [Domain-Adversarial Training of Neural Networks](https://arxiv.org/abs/1505.07818).
#
class GradientReversalLayer(tf.layers.Layer):
    def __init__(self, weight=1.0):
        super(GradientReversalLayer, self).__init__()
        self.weight = weight

    def call(self, input_):
        @tf.custom_gradient
        def _call(input_):
            def reversed_gradient(output_grads):
                return self.weight * tf.negative(output_grads)
            return input_, reversed_gradient
        return _call(input_)

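# A minimal usage sketch (the tensor `x` and the weight 0.5 are hypothetical,
# purely for illustration):
#
#     reversal = GradientReversalLayer(weight=0.5)
#     y = reversal(x)  # forward pass: y == x
#     # backward pass: dL/dx == -0.5 * dL/dy
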
# ## The model function
# The network consists of three sub-networks:
#
# * Feature extractor: extracts an internal representation for both the
#   source and target distributions.
#
# * Label predictor: predicts the label from the extracted features.
#
# * Domain classifier: classifies the origin (`source` or `target`) of the
#   extracted features.
#
#
# Both the label predictor and the domain classifier will try to minimize
# classification loss, but the gradients backpropagated from the domain
# classifier to the feature extractor have their signs reversed.
#
#
# This model function also shows how to use `host_call` to output summaries.
#
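# Schematically, with the reversal applied on the target branch:
#
#     source -> feature_extractor -> label_predictor              (label loss)
#     source -> feature_extractor ---------------------+
#     target -> feature_extractor -> gradient_reversal +-> domain_classifier
#                                                          (domain loss)
#
#     loss = label_prediction_loss + lambda * domain_classification_loss
#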
def model_fn(features, labels, mode, params):
    source = features['source']
    target = features['target']
    onehot_labels = tf.one_hot(labels, N_CLASSES)
    global_step = tf.train.get_global_step()

    # In this sample we use dense layers for each of the sub-networks.
    feature_extractor = tf.layers.Dense(7, activation=tf.nn.sigmoid)
    label_predictor_logits = tf.layers.Dense(N_CLASSES)
    # There are two domains, 0: source and 1: target
    domain_classifier_logits = tf.layers.Dense(2)

    source_features = feature_extractor(source)
    target_features = feature_extractor(target)

    # Apply the gradient reversal layer to target features
    gr_weight = params['gr_weight']
    gradient_reversal = GradientReversalLayer(gr_weight)
    target_features = gradient_reversal(target_features)

    # The predictions are the predicted labels from the `target` distribution.
    predictions = tf.nn.softmax(label_predictor_logits(target_features))

    loss = None
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # define loss
        label_prediction_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=onehot_labels,
            logits=label_predictor_logits(source_features))

        # There are two domains, 0: source and 1: target
        domain_labels = tf.concat(
            (tf.zeros(source.shape[0], dtype=tf.int32),
             tf.ones(target.shape[0], dtype=tf.int32)),
            axis=0)
        domain_onehot_labels = tf.one_hot(domain_labels, 2)
        source_target_features = tf.concat(
            [source_features, target_features], axis=0)
        domain_classification_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=domain_onehot_labels,
            logits=domain_classifier_logits(source_target_features))

        lambda_ = params['lambda']
        loss = label_prediction_loss + lambda_ * domain_classification_loss

        # define train_op
        optimizer = tf.train.RMSPropOptimizer(learning_rate=0.05)
        # wrapper to make the optimizer work with TPUs
        if params['use_tpu']:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss, global_step=global_step)

    if params['use_tpu']:
        # Use host_call to log the losses on the CPU
        def host_call_fn(gs, lpl, dcl, ls):
            gs = gs[0]
            with summary.create_file_writer(
                    params['model_dir'],
                    max_queue=params['save_checkpoints_steps']).as_default():
                with summary.always_record_summaries():
                    summary.scalar('label_prediction_loss', lpl[0], step=gs)
                    summary.scalar('domain_classification_loss', dcl[0],
                                   step=gs)
                    summary.scalar('loss', ls[0], step=gs)
                    return summary.all_summary_ops()

        # host_call's arguments must be at least 1D
        gs_t = tf.reshape(global_step, [1])
        lpl_t = tf.reshape(label_prediction_loss, [1])
        dcl_t = tf.reshape(domain_classification_loss, [1])
        ls_t = tf.reshape(loss, [1])
        host_call = (host_call_fn, [gs_t, lpl_t, dcl_t, ls_t])

        # TPU version of EstimatorSpec
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            host_call=host_call)
    else:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op)

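# Note on `host_call`: TPUEstimator requires every tensor passed to
# `host_call_fn` to carry a leading batch dimension, which is why the scalar
# losses and the global step are reshaped to `[1]` above and indexed back out
# with `[0]` inside `host_call_fn`.
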
# ## The input function
# There are two input datasets: `source` is labeled and `target` is unlabeled.
def train_input_fn(params={}):
    # source distribution: labeled data
    source = np.random.rand(100, 5)
    labels = np.random.randint(0, N_CLASSES, 100)
    # target distribution: unlabeled data
    target = np.random.rand(100, 5)

    source_tensor = tf.constant(source, dtype=tf.float32)
    labels_tensor = tf.constant(labels, dtype=tf.int32)
    target_tensor = tf.constant(target, dtype=tf.float32)

    # shuffle source and target separately
    source_labels_dataset = tf.data.Dataset.from_tensor_slices(
        (source_tensor, labels_tensor)).repeat().shuffle(32)
    target_dataset = tf.data.Dataset.from_tensor_slices(
        target_tensor).repeat().shuffle(32)

    # zip them together to set shapes
    dataset = tf.data.Dataset.zip((source_labels_dataset, target_dataset))

    # TPUEstimator passes params when calling input_fn
    batch_size = params.get('batch_size', 16)
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # TPUs need to know all dimensions when the graph is built.
    # Datasets know the batch size only when the graph is run.
    def set_shapes_and_format(source_labels, target):
        source, labels = source_labels
        source_shape = source.get_shape().merge_with([batch_size, None])
        labels_shape = labels.get_shape().merge_with([batch_size])
        target_shape = target.get_shape().merge_with([batch_size, None])

        source.set_shape(source_shape)
        labels.set_shape(labels_shape)
        target.set_shape(target_shape)

        # Also format the dataset with a dict for features
        features = {'source': source, 'target': target}
        return features, labels

    dataset = dataset.map(set_shapes_and_format)
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

    return dataset

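# Each batch produced by this input function has the structure (shapes shown
# for the default batch size of 16):
#
#     features = {'source': float32 [16, 5], 'target': float32 [16, 5]}
#     labels   = int32 [16]
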
def main(args):
    # pass the args as params so the model_fn can use
    # the TPU specific args
    params = vars(args)

    if args.use_tpu:
        # additional configs required for using TPUs
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            args.tpu)
        tpu_config = tf.contrib.tpu.TPUConfig(
            num_shards=8,  # using Cloud TPU v2-8
            iterations_per_loop=args.save_checkpoints_steps)

        # use the TPU version of RunConfig
        config = tf.contrib.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=args.model_dir,
            tpu_config=tpu_config,
            save_checkpoints_steps=args.save_checkpoints_steps,
            save_summary_steps=100)

        # TPUEstimator
        estimator = tf.contrib.tpu.TPUEstimator(
            model_fn=model_fn,
            config=config,
            params=params,
            train_batch_size=args.train_batch_size,
            eval_batch_size=32,
            export_to_tpu=False)
    else:
        config = tf.estimator.RunConfig(model_dir=args.model_dir)
        estimator = tf.estimator.Estimator(
            model_fn,
            config=config,
            params=params)

    estimator.train(train_input_fn, max_steps=args.max_steps)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-dir',
        type=str,
        default='/tmp/tpu-template',
        help='Location to write checkpoints and summaries to. '
             'Must be a GCS URI when using Cloud TPU.')
    parser.add_argument(
        '--max-steps',
        type=int,
        default=1000,
        help='The total number of steps to train the model.')
    parser.add_argument(
        '--train-batch-size',
        type=int,
        default=16,
        help='The training batch size. The training batch is divided evenly '
             'across the TPU cores.')
    parser.add_argument(
        '--save-checkpoints-steps',
        type=int,
        default=100,
        help='The number of training steps before saving each checkpoint.')
    parser.add_argument(
        '--use-tpu',
        action='store_true',
        help='Whether to use TPU.')
    parser.add_argument(
        '--tpu',
        default=None,
        help='The name or GRPC URL of the TPU node. Leave it as `None` when '
             'training on AI Platform.')
    parser.add_argument(
        '--gr-weight',
        type=float,
        default=1.0,
        help='The weight used in the gradient reversal layer.')
    parser.add_argument(
        '--lambda',
        type=float,
        default=1.0,
        help='The trade-off between label_prediction_loss and '
             'domain_classification_loss.')
    args, _ = parser.parse_known_args()
    main(args)
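
# Example invocations (the bucket and TPU names below are placeholders):
#
#     # local CPU/GPU run with the defaults:
#     python trainer.py
#
#     # Cloud TPU run; the model dir must be a GCS URI:
#     python trainer.py --use-tpu --tpu my-tpu-name \
#         --model-dir gs://my-bucket/tpu-template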