-
Notifications
You must be signed in to change notification settings - Fork 2
/
tf_utils.py
96 lines (93 loc) · 3.68 KB
/
tf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python
# encoding: utf-8
'''
@author: wanjinchang
@license: (C) Copyright 2013-2017, Node Supply Chain Manager Corporation Limited.
@contact: wanjinchang1991@gmail.com
@software: pycharm
@file: tf_utils.py
@time: 18-3-8 下午3:43
@desc:
'''
import tensorflow as tf
def configure_learning_rate(flags, num_samples_per_epoch, global_step):
    """Configures the learning rate schedule.

    Args:
        flags: An object (e.g. tf.app.flags.FLAGS) providing `batch_size`,
            `num_epochs_per_decay`, `learning_rate`, `learning_rate_decay_type`
            ('exponential' | 'fixed' | 'polynomial'), and, depending on the
            decay type, `learning_rate_decay_factor` / `end_learning_rate`.
        num_samples_per_epoch: The number of samples in each epoch of training.
        global_step: The global_step tensor.

    Returns:
        A `Tensor` representing the learning rate.

    Raises:
        ValueError: If `flags.learning_rate_decay_type` is not recognized.
    """
    # Decay once every `num_epochs_per_decay` epochs, measured in steps.
    decay_steps = int(num_samples_per_epoch / flags.batch_size *
                      flags.num_epochs_per_decay)
    if flags.learning_rate_decay_type == 'exponential':
        return tf.train.exponential_decay(flags.learning_rate,
                                          global_step,
                                          decay_steps,
                                          flags.learning_rate_decay_factor,
                                          staircase=True,
                                          name='exponential_decay_learning_rate')
    elif flags.learning_rate_decay_type == 'fixed':
        return tf.constant(flags.learning_rate, name='fixed_learning_rate')
    elif flags.learning_rate_decay_type == 'polynomial':
        return tf.train.polynomial_decay(flags.learning_rate,
                                         global_step,
                                         decay_steps,
                                         flags.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name='polynomial_decay_learning_rate')
    else:
        # Interpolate the bad value into the message; passing it as a second
        # exception argument (the original code) leaves a literal '%s' in the
        # printed message.
        raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                         flags.learning_rate_decay_type)
def configure_optimizer(flags, learning_rate):
    """Configures the optimizer used for training.

    Args:
        flags: An object (e.g. tf.app.flags.FLAGS) providing `optimizer`
            (one of 'adadelta', 'adagrad', 'adam', 'ftrl', 'momentum',
            'rmsprop', 'sgd') plus the hyperparameter flags that the chosen
            optimizer reads (e.g. `opt_epsilon`, `momentum`, `rmsprop_decay`).
        learning_rate: A scalar or `Tensor` learning rate.

    Returns:
        An instance of an optimizer.

    Raises:
        ValueError: If `flags.optimizer` is not recognized.
    """
    if flags.optimizer == 'adadelta':
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate,
            rho=flags.adadelta_rho,
            epsilon=flags.opt_epsilon
        )
    elif flags.optimizer == 'adagrad':
        optimizer = tf.train.AdagradOptimizer(
            learning_rate,
            initial_accumulator_value=flags.adagrad_initial_accumulator_value
        )
    elif flags.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate,
            beta1=flags.adam_beta1,
            beta2=flags.adam_beta2,
            epsilon=flags.opt_epsilon
        )
    elif flags.optimizer == 'ftrl':
        optimizer = tf.train.FtrlOptimizer(
            learning_rate,
            learning_rate_power=flags.ftrl_learning_rate_power,
            initial_accumulator_value=flags.ftrl_initial_accumulator_value,
            l1_regularization_strength=flags.ftrl_l1,
            l2_regularization_strength=flags.ftrl_l2
        )
    elif flags.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(
            learning_rate,
            momentum=flags.momentum,
            name='Momentum'
        )
    elif flags.optimizer == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=flags.rmsprop_decay,
            momentum=flags.rmsprop_momentum,
            epsilon=flags.opt_epsilon
        )
    elif flags.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        # Fixed typo ('Optimizef') and interpolate the bad value into the
        # message instead of passing it as a second exception argument.
        raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
    return optimizer