-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathscheduler.py
30 lines (23 loc) · 922 Bytes
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorflow as tf
class LinearScheduler:
    """Linearly anneal a TensorFlow variable from ``initial_value`` to zero.

    Each call to :meth:`decay` with the current step sets the variable to
    ``initial_value * max(0, 1 - step / final_step)``. The value reaches zero
    at ``final_step`` and stays there for later steps.

    Built on TF1 graph-mode APIs (``tf.placeholder`` plus a session).
    """

    def __init__(self, initial_value, final_step, name):
        """Create the scheduled variable and the op that overwrites it.

        Args:
            initial_value: Value of the variable at step 0.
            final_step: Step at which the value reaches zero. A value
                ``<= 0`` means the schedule is already finished, so
                :meth:`decay` always assigns zero.
            name: Name given to the created ``tf.Variable``.
        """
        self.final_step = final_step
        self.initial_value = initial_value
        # Pin the dtype to float32 so it matches the float32 placeholder below.
        # Without this, an integer ``initial_value`` yields an int32 variable
        # and ``assign`` fails at graph-construction time.
        self.variable = tf.Variable(initial_value, name=name, dtype=tf.float32)
        self.decayed_ph = tf.placeholder(tf.float32)
        self.decay_op = self.variable.assign(self.decayed_ph)

    @staticmethod
    def _remaining_fraction(step, final_step):
        """Return the fraction of the initial value left at ``step`` (never below 0)."""
        if final_step <= 0:
            # A zero-length schedule is already complete; avoid dividing by zero.
            return 0.0
        remaining = 1.0 - (float(step) / final_step)
        if remaining < 0.0:
            remaining = 0.0
        return remaining

    def decay(self, step, sess=None):
        """Assign the variable its scheduled value for ``step``.

        Args:
            step: Current training step.
            sess: Session used to run the assign op. Defaults to the
                current default session.

        Raises:
            RuntimeError: If ``sess`` is None and no default session is active.
        """
        if sess is None:
            sess = tf.get_default_session()
            if sess is None:
                raise RuntimeError(
                    "LinearScheduler.decay needs a session: pass sess= or "
                    "call it inside a `with tf.Session():` block"
                )
        fraction = self._remaining_fraction(step, self.final_step)
        feed_dict = {self.decayed_ph: fraction * self.initial_value}
        sess.run(self.decay_op, feed_dict=feed_dict)

    def get_variable(self):
        """Return the scheduled ``tf.Variable``."""
        return self.variable
class ConstantScheduler:
    """Scheduler whose variable keeps its initial value forever.

    Mirrors the ``decay`` / ``get_variable`` methods of ``LinearScheduler``,
    but ``decay`` never modifies the variable.
    """

    def __init__(self, initial_value, name):
        """Create the held variable.

        Args:
            initial_value: Value held by the variable.
            name: Name given to the created ``tf.Variable``.
        """
        self.variable = tf.Variable(initial_value, name=name)

    def decay(self, step, sess=None):
        """Do nothing.

        ``step`` and the optional ``sess`` are accepted and ignored, so a
        caller can drive any scheduler with the same call.
        """
        return None

    def get_variable(self):
        """Return the underlying ``tf.Variable``."""
        return self.variable