Update PolynomialLR's doc and parameter (#8430)
* update PolynomialLR doc: current_batch = min(decay_batch, current_batch)
* rename the steps parameter to decay_batch
* update PolynomialLR test case

Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored Jun 20, 2022
1 parent a1e91da commit 3cbf392
Showing 3 changed files with 11 additions and 9 deletions.
14 changes: 8 additions & 6 deletions python/oneflow/nn/optimizer/polynomial_lr.py
Original file line number Diff line number Diff line change
@@ -36,13 +36,13 @@ class PolynomialLR(LRScheduler):
     .. math::
         \begin{aligned}
-            & decay\_batch = min(decay\_batch, current\_batch) \\
+            & current\_batch = min(decay\_batch, current\_batch) \\
             & learning\_rate = (base\_lr-end\_lr)*(1-\frac{current\_batch}{decay\_batch})^{power}+end\_lr
         \end{aligned}
     Args:
         optimizer (Optimizer): Wrapper optimizer.
-        steps (int): The decayed steps.
+        decay_batch (int): The decayed steps.
         end_learning_rate (float, optional): The final learning rate. Defaults to 0.0001.
         power (float, optional): The power of polynomial. Defaults to 1.0.
         cycle (bool, optional): If cycle is True, the scheduler will decay the learning rate every decay steps. Defaults to False.
@@ -55,7 +55,7 @@ class PolynomialLR(LRScheduler):
         ...
         polynomial_scheduler = flow.optim.lr_scheduler.PolynomialLR(
-            optimizer, steps=5, end_learning_rate=0.00001, power=2
+            optimizer, decay_batch=5, end_learning_rate=0.00001, power=2
         )
         for epoch in range(num_epoch):
@@ -66,15 +66,17 @@ class PolynomialLR(LRScheduler):
     def __init__(
         self,
         optimizer,
-        steps: int,
+        decay_batch: int,
         end_learning_rate: float = 0.0001,
         power: float = 1.0,
         cycle: bool = False,
         last_step: int = -1,
         verbose: bool = False,
     ):
-        assert steps > 0, f"steps must greater than zero, but got {steps}"
-        self.max_decay_steps = steps
+        assert (
+            decay_batch > 0
+        ), f"decay_batch must greater than zero, but got {decay_batch}"
+        self.max_decay_steps = decay_batch
         self.end_learning_rate = end_learning_rate
         self.power = power
         self.cycle = cycle
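For reference, the decay rule the corrected docstring describes can be sketched as a small standalone function. This is a minimal illustration, not OneFlow's implementation: polynomial_lr is a hypothetical helper, and the cycle branch assumes the usual polynomial-decay convention (decay_batch grows to the next multiple that covers current_batch).

import math

def polynomial_lr(base_lr, end_lr, current_batch, decay_batch, power=1.0, cycle=False):
    # Hypothetical sketch of the docstring's rule, not OneFlow code.
    if cycle:
        # Assumed cycle semantics: restart the decay by extending decay_batch
        # to the next multiple that covers current_batch.
        decay_batch = decay_batch * max(1, math.ceil(current_batch / decay_batch))
    else:
        # The clamp this commit documents: after decay_batch steps the LR
        # holds at end_lr instead of the ratio exceeding 1.
        current_batch = min(decay_batch, current_batch)
    return (base_lr - end_lr) * (1 - current_batch / decay_batch) ** power + end_lr

For example, with an assumed base_lr of 0.001 and the docstring example's decay_batch=5, power=2: polynomial_lr(0.001, 0.00001, 0, 5) returns 0.001, and polynomial_lr(0.001, 0.00001, 5, 5) returns 0.00001, where it stays for all later batches when cycle=False.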
4 changes: 2 additions & 2 deletions python/oneflow/test/graph/test_graph_lr_scheduler.py
@@ -181,7 +181,7 @@ def test_polynomial_lr(self):
             base_lr=0.1,
             iters=20,
             lr_scheduler=flow.optim.lr_scheduler.PolynomialLR,
-            steps=20,
+            decay_batch=20,
             end_learning_rate=1e-5,
             power=2.0,
             atol=1e-5,
@@ -191,7 +191,7 @@ def test_polynomial_lr(self):
             base_lr=0.01,
             iters=20,
             lr_scheduler=flow.optim.lr_scheduler.PolynomialLR,
-            steps=20,
+            decay_batch=20,
             end_learning_rate=1e-4,
             power=1.0,
             cycle=True,
2 changes: 1 addition & 1 deletion python/oneflow/test/graph/test_graph_lrs.py
@@ -183,7 +183,7 @@ def _lr_fn(parameters):
         of_sgd = flow.optim.SGD(parameters, lr=0.001)

         lr = flow.optim.lr_scheduler.PolynomialLR(
-            of_sgd, steps=10, end_learning_rate=0.00001, power=2, cycle=True
+            of_sgd, decay_batch=10, end_learning_rate=0.00001, power=2, cycle=True
         )
         return of_sgd, lr

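Taken together, the three files show that call sites migrate by renaming the keyword only; the schedule's behavior is unchanged apart from the documented clamp. A minimal sketch of a call site after this commit, assuming a trivial parameter list for illustration:

import oneflow as flow

params = [flow.nn.Parameter(flow.ones(2, 2))]
sgd = flow.optim.SGD(params, lr=0.001)
# steps=10 before this commit becomes decay_batch=10 after it.
scheduler = flow.optim.lr_scheduler.PolynomialLR(
    sgd, decay_batch=10, end_learning_rate=0.00001, power=2, cycle=True
)
# A training loop would then call sgd.step() followed by scheduler.step().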
