feat: upgrade ms to 2.2
The-truthh committed Jan 17, 2024
1 parent 67ba9a2 commit d0f8a95
Showing 5 changed files with 94 additions and 6 deletions.
mindcv/models/hrnet.py (16 changes: 10 additions & 6 deletions)
@@ -4,6 +4,7 @@
 """
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
+import mindspore as ms
 import mindspore.nn as nn
 import mindspore.ops as ops
 from mindspore import Tensor
@@ -329,23 +330,26 @@ def construct(self, x: List[Tensor]) -> List[Tensor]:
         if self.num_branches == 1:
             return [self.branches[0](x[0])]
 
+        x2 = []
         for i in range(self.num_branches):
-            x[i] = self.branches[i](x[i])
+            x2.append(self.branches[i](x[i]))
 
         x_fuse = []
 
         for i in range(len(self.fuse_layers)):
-            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+            y = x2[0] if i == 0 else self.fuse_layers[i][0](x2[0])
             for j in range(1, self.num_branches):
                 if i == j:
-                    y = y + x[j]
+                    y = y + x2[j]
                 elif j > i:
-                    _, _, height, width = x[i].shape
-                    t = self.fuse_layers[i][j](x[j])
+                    _, _, height, width = x2[i].shape
+                    t = self.fuse_layers[i][j](x2[j])
+                    t = ops.cast(t, ms.float32)
                     t = ops.ResizeNearestNeighbor((height, width))(t)
+                    t = ops.cast(t, ms.float16)
                     y = y + t
                 else:
-                    y = y + self.fuse_layers[i][j](x[j])
+                    y = y + self.fuse_layers[i][j](x2[j])
             x_fuse.append(self.relu(y))
 
         if not self.multi_scale_output:
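The hrnet.py change does two things: it stops writing results back into the input list x (building a new list x2 instead, likely because in-place assignment to a function input can be problematic under graph compilation), and it wraps ops.ResizeNearestNeighbor in a pair of casts, presumably because the resize primitive does not accept float16 inputs under MindSpore 2.2 on the targeted backend. Below is a minimal, standalone sketch of that cast-around-resize pattern (not part of the commit); the helper name resize_nearest_fp32 is hypothetical, and unlike the diff, which casts back to float16 unconditionally, it restores whatever dtype came in.

import numpy as np

import mindspore as ms
import mindspore.ops as ops
from mindspore import Tensor


def resize_nearest_fp32(t: Tensor, height: int, width: int) -> Tensor:
    """Nearest-neighbor resize that runs the op in float32, then restores the input dtype."""
    in_dtype = t.dtype
    t = ops.cast(t, ms.float32)                        # resize op runs in float32
    t = ops.ResizeNearestNeighbor((height, width))(t)  # same primitive as in the diff
    return ops.cast(t, in_dtype)                       # back to the caller's dtype


x = Tensor(np.random.randn(1, 32, 14, 14), ms.float16)
y = resize_nearest_fp32(x, 28, 28)  # shape (1, 32, 28, 28), dtype float16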
mindcv/optim/adamw.py (21 changes: 21 additions & 0 deletions)
@@ -159,6 +159,27 @@ def __init__(
         self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
         self.clip = clip
 
+    def get_lr(self):
+        """
+        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
+        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.learning_rate(self.global_step).reshape(())
+        if self._is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
+        return lr
+
     @jit
     def construct(self, gradients):
         lr = self.get_lr()
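The get_lr method added here (adan.py, lion.py and nadam.py below receive an identical copy) appears to mirror the base mindspore.nn.Optimizer implementation: it resolves the learning rate for the current step from a dynamic schedule, per parameter group if one is configured, and advances global_step when the learning rate or weight decay is dynamic. As the docstring notes, a user-defined optimizer built on mindspore.nn.Optimizer calls this interface before updating its parameters; a minimal PyNative-style sketch of that pattern follows. PlainSGD is illustrative only (not part of mindcv or MindSpore) and assumes a single, non-grouped learning rate.

import mindspore.nn as nn
import mindspore.ops as ops


class PlainSGD(nn.Optimizer):
    """Toy optimizer: param <- param - lr * grad. For illustration only."""

    def __init__(self, params, learning_rate=0.01):
        super().__init__(learning_rate, params)

    def construct(self, gradients):
        # Per-step learning rate as a scalar Tensor; with an override like the one above,
        # this also advances global_step when the lr or weight decay is dynamic.
        lr = self.get_lr()
        for param, grad in zip(self._parameters, gradients):
            update = ops.cast(lr * grad, param.dtype)  # keep the update in the parameter's dtype
            ops.assign_sub(param, update)              # param <- param - update
        return True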
mindcv/optim/adan.py (21 changes: 21 additions & 0 deletions)
@@ -149,6 +149,27 @@ def __init__(
 
         self.weight_decay = Tensor(weight_decay, mstype.float32)
 
+    def get_lr(self):
+        """
+        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
+        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.learning_rate(self.global_step).reshape(())
+        if self._is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
+        return lr
+
     @jit
     def construct(self, gradients):
         params = self._parameters
mindcv/optim/lion.py (21 changes: 21 additions & 0 deletions)
@@ -147,6 +147,27 @@ def __init__(
         self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
         self.clip = clip
 
+    def get_lr(self):
+        """
+        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
+        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.learning_rate(self.global_step).reshape(())
+        if self._is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
+        return lr
+
     @jit
     def construct(self, gradients):
         lr = self.get_lr()
mindcv/optim/nadam.py (21 changes: 21 additions & 0 deletions)
@@ -53,6 +53,27 @@ def __init__(
         self.mu_schedule = Parameter(initializer(1, [1], ms.float32), name="mu_schedule")
         self.beta2_power = Parameter(initializer(1, [1], ms.float32), name="beta2_power")
 
+    def get_lr(self):
+        """
+        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
+        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.learning_rate(self.global_step).reshape(())
+        if self._is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
+        return lr
+
     @jit
     def construct(self, gradients):
         lr = self.get_lr()
