diff --git a/model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py b/model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py index c5c26c28eb..d30d8274b4 100644 --- a/model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py +++ b/model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""FasterRcnn anchor generator.""" +"""CTPN anchor generator.""" import numpy as np class AnchorGenerator(): - """Anchor generator for FasterRcnn.""" + """Anchor generator for CTPN.""" def __init__(self, config): """Anchor generator init method.""" self.base_size = config.anchor_base diff --git a/model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py b/model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py index 18d2aa1130..7ca88d3bc1 100644 --- a/model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py +++ b/model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""FasterRcnn positive and negative sample screening for RPN.""" +"""CTPN positive and negative sample screening for RPN.""" import numpy as np import mindspore.nn as nn diff --git a/model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py b/model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py index 34b187fbf7..f61af38e52 100644 --- a/model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py +++ b/model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""FasterRcnn proposal generator.""" +"""CTPN proposal generator.""" import numpy as np import mindspore.nn as nn diff --git a/model_zoo/official/cv/ctpn/src/CTPN/rpn.py b/model_zoo/official/cv/ctpn/src/CTPN/rpn.py index 46826c66fe..90d568a7c8 100644 --- a/model_zoo/official/cv/ctpn/src/CTPN/rpn.py +++ b/model_zoo/official/cv/ctpn/src/CTPN/rpn.py @@ -49,25 +49,24 @@ class RpnRegClsBlock(nn.Cell): self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16) self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16) self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16) - self.shape1 = (config.num_step, config.rnn_batch_size, -1) - self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step) + self.shape1 = (-1, config.num_step, config.rnn_batch_size) + self.shape2 = (config.batch_size, -1, config.rnn_batch_size, config.num_step) self.transpose = P.Transpose() self.print = P.Print() - self.dropout = nn.Dropout(0.8) def construct(self, x): x = self.reshape(x, self.shape) x = self.lstm_fc(x) x1 = self.rpn_cls(x) + x1 = self.transpose(x1, (1, 0)) x1 = self.reshape(x1, self.shape1) - x1 = self.transpose(x1, (2, 1, 0)) + x1 = self.transpose(x1, (0, 2, 1)) x1 = self.reshape(x1, self.shape2) - x1 = self.transpose(x1, (1, 0, 2, 3)) x2 = self.rpn_reg(x) + x2 = self.transpose(x2, (1, 0)) x2 = self.reshape(x2, self.shape1) - x2 = self.transpose(x2, (2, 1, 0)) + x2 = self.transpose(x2, (0, 2, 1)) x2 = self.reshape(x2, self.shape2) - x2 = self.transpose(x2, (1, 0, 2, 3)) return x1, x2 class RPN(nn.Cell): diff --git a/model_zoo/official/cv/ctpn/src/dataset.py b/model_zoo/official/cv/ctpn/src/dataset.py index cebe212b80..cdc4cc582f 100644 --- a/model_zoo/official/cv/ctpn/src/dataset.py +++ b/model_zoo/official/cv/ctpn/src/dataset.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""FasterRcnn dataset""" +"""CTPN dataset""" from __future__ import division import os import numpy as np diff --git a/model_zoo/official/cv/ctpn/src/network_define.py b/model_zoo/official/cv/ctpn/src/network_define.py index d31586c7d5..f352720518 100644 --- a/model_zoo/official/cv/ctpn/src/network_define.py +++ b/model_zoo/official/cv/ctpn/src/network_define.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""FasterRcnn training network wrapper.""" +"""CTPN training network wrapper.""" import time import numpy as np @@ -82,7 +82,7 @@ class LossCallBack(Callback): loss_file.close() class LossNet(nn.Cell): - """FasterRcnn loss method""" + """CTPN loss method""" def construct(self, x1, x2, x3): return x1