Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIPC]update tipc scripts and rm fluid api #11098

Merged
merged 3 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 64 additions & 56 deletions ppocr/modeling/heads/rec_robustscanner_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/channel_reduction_encoder.py
Expand All @@ -28,6 +27,7 @@
import paddle.nn as nn
import paddle.nn.functional as F


class BaseDecoder(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
Expand All @@ -42,15 +42,17 @@ def forward(self,
feat,
out_enc,
label=None,
valid_ratios=None,
valid_ratios=None,
word_positions=None,
train_mode=True):
self.train_mode = train_mode

if train_mode:
return self.forward_train(feat, out_enc, label, valid_ratios, word_positions)
return self.forward_train(feat, out_enc, label, valid_ratios,
word_positions)
return self.forward_test(feat, out_enc, valid_ratios, word_positions)


class ChannelReductionEncoder(nn.Layer):
"""Change the channel number with a one by one convoluational layer.

Expand All @@ -59,14 +61,16 @@ class ChannelReductionEncoder(nn.Layer):
out_channels (int): Number of output channels.
"""

def __init__(self,
in_channels,
out_channels,
**kwargs):
def __init__(self, in_channels, out_channels, **kwargs):
super(ChannelReductionEncoder, self).__init__()

self.layer = nn.Conv2D(
in_channels, out_channels, kernel_size=1, stride=1, padding=0, weight_attr=nn.initializer.XavierNormal())
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
weight_attr=nn.initializer.XavierNormal())

def forward(self, feat):
"""
Expand All @@ -84,8 +88,8 @@ def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)

class DotProductAttentionLayer(nn.Layer):

class DotProductAttentionLayer(nn.Layer):
def __init__(self, dim_model=None):
super().__init__()

Expand All @@ -99,7 +103,7 @@ def forward(self, query, key, value, h, w, valid_ratios=None):
logits = paddle.reshape(logits, [n, c, h, w])
if valid_ratios is not None:
# cal mask of attention weight
with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
with paddle.base.framework._stride_in_no_check_dy2st_diff():
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, int(w * valid_ratio + 0.5))
if valid_width < w:
Expand All @@ -113,6 +117,7 @@ def forward(self, query, key, value, h, w, valid_ratios=None):
glimpse = paddle.transpose(glimpse, (0, 2, 1))
return glimpse


class SequenceAttentionDecoder(BaseDecoder):
"""Sequence attention decoder for RobustScanner.

Expand Down Expand Up @@ -181,8 +186,8 @@ def __init__(self,
self.prediction = None
if not self.return_feature:
pred_num_classes = num_classes - 1
self.prediction = nn.Linear(
dim_model if encode_value else dim_input, pred_num_classes)
self.prediction = nn.Linear(dim_model if encode_value else
dim_input, pred_num_classes)

def forward_train(self, feat, out_enc, targets, valid_ratios):
"""
Expand Down Expand Up @@ -243,12 +248,13 @@ def forward_test(self, feat, out_enc, valid_ratios):
seq_len = self.max_seq_len
batch_size = feat.shape[0]

decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
decode_sequence = (paddle.ones(
(batch_size, seq_len), dtype='int64') * self.start_idx)

outputs = []
for i in range(seq_len):
step_out = self.forward_test_step(feat, out_enc, decode_sequence,
i, valid_ratios)
step_out = self.forward_test_step(feat, out_enc, decode_sequence, i,
valid_ratios)
outputs.append(step_out)
max_idx = paddle.argmax(step_out, axis=1, keepdim=False)
if i < seq_len - 1:
Expand All @@ -274,7 +280,7 @@ def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
tokens at current time step.
"""

embed = self.embedding(decode_sequence)

n, c_enc, h, w = out_enc.shape
Expand Down Expand Up @@ -306,7 +312,6 @@ def forward_test_step(self, feat, out_enc, decode_sequence, current_step,


class PositionAwareLayer(nn.Layer):

def __init__(self, dim_model, rnn_layers=2):
super().__init__()

Expand Down Expand Up @@ -384,16 +389,16 @@ def __init__(self,

self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)

self.position_aware_module = PositionAwareLayer(
self.dim_model, rnn_layers)
self.position_aware_module = PositionAwareLayer(self.dim_model,
rnn_layers)

self.attention_layer = DotProductAttentionLayer()

self.prediction = None
if not self.return_feature:
pred_num_classes = num_classes - 1
self.prediction = nn.Linear(
dim_model if encode_value else dim_input, pred_num_classes)
self.prediction = nn.Linear(dim_model if encode_value else
dim_input, pred_num_classes)

def _get_position_index(self, length, batch_size):
position_index_list = []
Expand All @@ -403,7 +408,8 @@ def _get_position_index(self, length, batch_size):
batch_position_index = paddle.stack(position_index_list, axis=0)
return batch_position_index

def forward_train(self, feat, out_enc, targets, valid_ratios, position_index):
def forward_train(self, feat, out_enc, targets, valid_ratios,
position_index):
"""
Args:
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
Expand All @@ -427,16 +433,16 @@ def forward_train(self, feat, out_enc, targets, valid_ratios, position_index):
assert c_feat == self.dim_input
_, len_q = targets.shape
assert len_q <= self.max_seq_len

position_out_enc = self.position_aware_module(out_enc)

query = self.embedding(position_index)
query = paddle.transpose(query, (0, 2, 1))
key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
if self.encode_value:
value = paddle.reshape(out_enc,(n, c_enc, h * w))
value = paddle.reshape(out_enc, (n, c_enc, h * w))
else:
value = paddle.reshape(feat,(n, c_feat, h * w))
value = paddle.reshape(feat, (n, c_feat, h * w))

attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
Expand Down Expand Up @@ -467,14 +473,14 @@ def forward_test(self, feat, out_enc, valid_ratios, position_index):
assert c_feat == self.dim_input

position_out_enc = self.position_aware_module(out_enc)

query = self.embedding(position_index)
query = paddle.transpose(query, (0, 2, 1))
key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
if self.encode_value:
value = paddle.reshape(out_enc,(n, c_enc, h * w))
value = paddle.reshape(out_enc, (n, c_enc, h * w))
else:
value = paddle.reshape(feat,(n, c_feat, h * w))
value = paddle.reshape(feat, (n, c_feat, h * w))

attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
Expand All @@ -484,8 +490,8 @@ def forward_test(self, feat, out_enc, valid_ratios, position_index):

return self.prediction(attn_out)

class RobustScannerFusionLayer(nn.Layer):

class RobustScannerFusionLayer(nn.Layer):
def __init__(self, dim_model, dim=-1):
super(RobustScannerFusionLayer, self).__init__()

Expand All @@ -500,6 +506,7 @@ def forward(self, x0, x1):
output = F.glu(output, self.dim)
return output


class RobustScannerDecoder(BaseDecoder):
"""Decoder for RobustScanner.

Expand Down Expand Up @@ -561,8 +568,7 @@ def __init__(self,
padding_idx=padding_idx,
dropout=hybrid_decoder_dropout,
encode_value=encode_value,
return_feature=True
)
return_feature=True)

# init position decoder
self.position_decoder = PositionAttentionDecoder(
Expand All @@ -573,9 +579,7 @@ def __init__(self,
max_seq_len=max_seq_len,
mask=mask,
encode_value=encode_value,
return_feature=True
)

return_feature=True)

self.fusion_module = RobustScannerFusionLayer(
self.dim_model if encode_value else dim_input)
Expand All @@ -584,7 +588,8 @@ def __init__(self,
self.prediction = nn.Linear(dim_model if encode_value else dim_input,
pred_num_classes)

def forward_train(self, feat, out_enc, target, valid_ratios, word_positions):
def forward_train(self, feat, out_enc, target, valid_ratios,
word_positions):
"""
Args:
feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
Expand All @@ -599,8 +604,8 @@ def forward_train(self, feat, out_enc, target, valid_ratios, word_positions):
Returns:
Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
"""
hybrid_glimpse = self.hybrid_decoder.forward_train(
feat, out_enc, target, valid_ratios)
hybrid_glimpse = self.hybrid_decoder.forward_train(feat, out_enc,
target, valid_ratios)
position_glimpse = self.position_decoder.forward_train(
feat, out_enc, target, valid_ratios, word_positions)

Expand All @@ -625,7 +630,8 @@ def forward_test(self, feat, out_enc, valid_ratios, word_positions):
seq_len = self.max_seq_len
batch_size = feat.shape[0]

decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx)
decode_sequence = (paddle.ones(
(batch_size, seq_len), dtype='int64') * self.start_idx)

position_glimpse = self.position_decoder.forward_test(
feat, out_enc, valid_ratios, word_positions)
Expand All @@ -649,28 +655,30 @@ def forward_test(self, feat, out_enc, valid_ratios, word_positions):

return outputs


class RobustScannerHead(nn.Layer):
def __init__(self,
out_channels, # 90 + unknown + start + padding
in_channels,
enc_outchannles=128,
hybrid_dec_rnn_layers=2,
hybrid_dec_dropout=0,
position_dec_rnn_layers=2,
start_idx=0,
max_text_length=40,
mask=True,
padding_idx=None,
encode_value=False,
**kwargs):
def __init__(
self,
out_channels, # 90 + unknown + start + padding
in_channels,
enc_outchannles=128,
hybrid_dec_rnn_layers=2,
hybrid_dec_dropout=0,
position_dec_rnn_layers=2,
start_idx=0,
max_text_length=40,
mask=True,
padding_idx=None,
encode_value=False,
**kwargs):
super(RobustScannerHead, self).__init__()

# encoder module
self.encoder = ChannelReductionEncoder(
in_channels=in_channels, out_channels=enc_outchannles)

# decoder module
self.decoder =RobustScannerDecoder(
self.decoder = RobustScannerDecoder(
num_classes=out_channels,
dim_input=in_channels,
dim_model=enc_outchannles,
Expand All @@ -693,18 +701,18 @@ def forward(self, inputs, targets=None):

if len(targets) > 1:
valid_ratios = targets[-2]

if self.training:
label = targets[0] # label
label = paddle.to_tensor(label, dtype='int64')
final_out = self.decoder(
inputs, out_enc, label, valid_ratios, word_positions)
final_out = self.decoder(inputs, out_enc, label, valid_ratios,
word_positions)
if not self.training:
final_out = self.decoder(
inputs,
out_enc,
label=None,
valid_ratios=valid_ratios,
valid_ratios=valid_ratios,
word_positions=word_positions,
train_mode=False)
return final_out
4 changes: 4 additions & 0 deletions test_tipc/test_train_inference_python_npu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ if [ $modelname == "rec_r31_sar" ] || [ $modelname == "rec_mtb_nrtr" ]; then
sed -i "s/gpu_list:0|0,1/gpu_list:0,1/g" $FILENAME
sed -i "s/Global.use_npu:True|True/Global.use_npu:True/g" $FILENAME
fi
# For the ch_ppocr_mobile_v2_0_rec_FPGM model, append a Global.use_gpu=False
# override to the end of lines 18 and 32 of $FILENAME (edited in place).
# Expansions are quoted so an empty/unset $modelname cannot turn the test into
# a "unary operator expected" error and a $FILENAME containing spaces or glob
# characters is passed to sed as a single argument.
# NOTE(review): line 32 is appended without a leading "-o", presumably because
# that line already carries one - confirm against the params file.
if [ "$modelname" == "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then
    sed -i '18s/$/ -o Global.use_gpu=False/' "$FILENAME"
    sed -i '32s/$/ Global.use_gpu=False/' "$FILENAME"
fi

# replace training config file
grep -n 'tools/.*yml' $FILENAME | cut -d ":" -f 1 \
Expand Down
4 changes: 4 additions & 0 deletions test_tipc/test_train_inference_python_xpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ if [ $modelname == "rec_r31_sar" ] || [ $modelname == "rec_mtb_nrtr" ]; then
sed -i "s/gpu_list:0|0,1/gpu_list:0,1/g" $FILENAME
sed -i "s/Global.use_xpu:True|True/Global.use_xpu:True/g" $FILENAME
fi
# For the ch_ppocr_mobile_v2_0_rec_FPGM model, append a Global.use_gpu=False
# override to the end of lines 18 and 32 of $FILENAME (edited in place).
# Expansions are quoted so an empty/unset $modelname cannot turn the test into
# a "unary operator expected" error and a $FILENAME containing spaces or glob
# characters is passed to sed as a single argument.
# NOTE(review): line 32 is appended without a leading "-o", presumably because
# that line already carries one - confirm against the params file.
if [ "$modelname" == "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then
    sed -i '18s/$/ -o Global.use_gpu=False/' "$FILENAME"
    sed -i '32s/$/ Global.use_gpu=False/' "$FILENAME"
fi

# replace training config file
grep -n 'tools/.*yml' $FILENAME | cut -d ":" -f 1 \
Expand Down