Added tf.distribute.MirroredStrategy() test to estimator_distributed_test_runner
chamorajg committed Jun 8, 2019
1 parent 3d8cd3a commit b21ed49
Showing 2 changed files with 43 additions and 0 deletions.
13 changes: 13 additions & 0 deletions adanet/core/estimator_distributed_test.py
@@ -244,6 +244,19 @@ def _wait_for_processes(self, wait_processes, kill_processes, timeout_secs):
"num_ps":
3,
},
{
"testcase_name":
"estimator_with_distributed_mirrored_strategy_{}_five_workers_three_ps"
.format(placement),
"estimator":
"estimator_with_distributed_mirrored_strategy",
"placement_strategy":
placement,
"num_workers":
5,
"num_ps":
3,
},
] for placement in ["replication", "round_robin"]]))
# pylint: enable=g-complex-comprehension
def test_distributed_training(self,
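For reference, a quick sketch (not part of the commit) of the test-case names the comprehension above generates for the new entry, assuming only the two placement values shown in the diff:

# Hypothetical illustration: the names the parameterized test decorator will
# see for the new MirroredStrategy case, one per placement strategy above.
for placement in ["replication", "round_robin"]:
  print("estimator_with_distributed_mirrored_strategy_{}_five_workers_three_ps"
        .format(placement))
# -> estimator_with_distributed_mirrored_strategy_replication_five_workers_three_ps
# -> estimator_with_distributed_mirrored_strategy_round_robin_five_workers_three_ps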
30 changes: 30 additions & 0 deletions adanet/core/estimator_distributed_test_runner.py
@@ -69,6 +69,7 @@
"estimator",
"autoensemble",
"autoensemble_trees_multiclass",
"estimator_with_distributed_mirrored_strategy"
], "The estimator type to train.")

flags.DEFINE_enum("placement_strategy", "replication", [
@@ -288,6 +289,35 @@ def tree_loss_fn(labels, logits):
    estimator = AutoEnsembleEstimator(
        head=head, candidate_pool=candidate_pool, **kwargs)

  elif FLAGS.estimator_type == "estimator_with_distributed_mirrored_strategy":

    def _model_fn(features, labels, mode):
      """Trains a single dense layer with mean squared error loss."""
      layer = tf.layers.Dense(1)
      logits = layer(features)

      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"logits": logits}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

      loss = tf.losses.mean_squared_error(
          labels=labels, predictions=tf.reshape(logits, []))

      if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    def _input_fn():
      # A constant dataset of 100 identical (features, label) pairs.
      features = tf.data.Dataset.from_tensors([[1.]]).repeat(100)
      labels = tf.data.Dataset.from_tensors(1.).repeat(100)
      return tf.data.Dataset.zip((features, labels))

    # Mirror the model across the available local devices and run training and
    # evaluation through the Estimator API.
    distribution = tf.distribute.MirroredStrategy()
    config = tf.estimator.RunConfig(train_distribute=distribution)
    classifier = tf.estimator.Estimator(model_fn=_model_fn, config=config)
    classifier.train(input_fn=_input_fn)
    classifier.evaluate(input_fn=_input_fn)

  def input_fn():
    input_features = {"x": tf.constant(features, name="x")}
    input_labels = tf.constant(labels, name="y")
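For context, a hedged sketch (not part of the commit) of how the runner might be launched with the new estimator type. The flag names come from the DEFINE_enum calls shown above and the script path from the diff header; the exact entry point and any additional required flags are assumptions.

# Hypothetical invocation of the test runner with the new flag value; other
# flags the script may require are not shown here.
import subprocess

subprocess.run([
    "python", "adanet/core/estimator_distributed_test_runner.py",
    "--estimator_type=estimator_with_distributed_mirrored_strategy",
    "--placement_strategy=replication",
], check=True)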
