!12517 enhance for ctpn performance

From: @qujianwei
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
pull/12517/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 264f265de8

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FasterRcnn anchor generator.""" """CTPN anchor generator."""
import numpy as np import numpy as np
class AnchorGenerator(): class AnchorGenerator():
"""Anchor generator for FasterRcnn.""" """Anchor generator for CTPN."""
def __init__(self, config): def __init__(self, config):
"""Anchor generator init method.""" """Anchor generator init method."""
self.base_size = config.anchor_base self.base_size = config.anchor_base

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FasterRcnn positive and negative sample screening for RPN.""" """CTPN positive and negative sample screening for RPN."""
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FasterRcnn proposal generator.""" """CTPN proposal generator."""
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn

@ -49,25 +49,24 @@ class RpnRegClsBlock(nn.Cell):
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16) self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16) self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16) self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
self.shape1 = (config.num_step, config.rnn_batch_size, -1) self.shape1 = (-1, config.num_step, config.rnn_batch_size)
self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step) self.shape2 = (config.batch_size, -1, config.rnn_batch_size, config.num_step)
self.transpose = P.Transpose() self.transpose = P.Transpose()
self.print = P.Print() self.print = P.Print()
self.dropout = nn.Dropout(0.8)
def construct(self, x): def construct(self, x):
x = self.reshape(x, self.shape) x = self.reshape(x, self.shape)
x = self.lstm_fc(x) x = self.lstm_fc(x)
x1 = self.rpn_cls(x) x1 = self.rpn_cls(x)
x1 = self.transpose(x1, (1, 0))
x1 = self.reshape(x1, self.shape1) x1 = self.reshape(x1, self.shape1)
x1 = self.transpose(x1, (2, 1, 0)) x1 = self.transpose(x1, (0, 2, 1))
x1 = self.reshape(x1, self.shape2) x1 = self.reshape(x1, self.shape2)
x1 = self.transpose(x1, (1, 0, 2, 3))
x2 = self.rpn_reg(x) x2 = self.rpn_reg(x)
x2 = self.transpose(x2, (1, 0))
x2 = self.reshape(x2, self.shape1) x2 = self.reshape(x2, self.shape1)
x2 = self.transpose(x2, (2, 1, 0)) x2 = self.transpose(x2, (0, 2, 1))
x2 = self.reshape(x2, self.shape2) x2 = self.reshape(x2, self.shape2)
x2 = self.transpose(x2, (1, 0, 2, 3))
return x1, x2 return x1, x2
class RPN(nn.Cell): class RPN(nn.Cell):

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FasterRcnn dataset""" """CTPN dataset"""
from __future__ import division from __future__ import division
import os import os
import numpy as np import numpy as np

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FasterRcnn training network wrapper.""" """CTPN training network wrapper."""
import time import time
import numpy as np import numpy as np
@ -82,7 +82,7 @@ class LossCallBack(Callback):
loss_file.close() loss_file.close()
class LossNet(nn.Cell): class LossNet(nn.Cell):
"""FasterRcnn loss method""" """CTPN loss method"""
def construct(self, x1, x2, x3): def construct(self, x1, x2, x3):
return x1 return x1

Loading…
Cancel
Save