enhance for ctpn performance

pull/12517/head
qujianwei 4 years ago
parent 8aba5d8f57
commit a95abd0cef

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""FasterRcnn anchor generator."""
+"""CTPN anchor generator."""
 import numpy as np
 class AnchorGenerator():
-    """Anchor generator for FasterRcnn."""
+    """Anchor generator for CTPN."""
     def __init__(self, config):
         """Anchor generator init method."""
         self.base_size = config.anchor_base
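Note: for context, a minimal numpy sketch of CTPN-style anchor generation. Every anchor shares a fixed width (tied here to base_size, which the class reads from config.anchor_base) and varies only in height; the height list below is the common CTPN default, not necessarily this repository's config.

    import numpy as np

    def ctpn_anchors(base_size=16, heights=(11, 16, 23, 33, 48, 68, 97, 139, 198, 283)):
        # All anchors are base_size wide and centered on the base cell;
        # only the height varies, matching CTPN's fixed-width text proposals.
        x_ctr = y_ctr = (base_size - 1) * 0.5
        anchors = []
        for h in heights:
            w = base_size
            anchors.append([x_ctr - (w - 1) * 0.5, y_ctr - (h - 1) * 0.5,
                            x_ctr + (w - 1) * 0.5, y_ctr + (h - 1) * 0.5])
        return np.array(anchors, dtype=np.float32)   # shape (len(heights), 4)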

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""FasterRcnn positive and negative sample screening for RPN."""
+"""CTPN positive and negative sample screening for RPN."""
 import numpy as np
 import mindspore.nn as nn
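Note: as a rough illustration of what "positive and negative sample screening" means here, a toy numpy version of IoU-threshold assignment is sketched below. The thresholds and the best-anchor-per-GT rule are generic RPN conventions, not values taken from this repository.

    import numpy as np

    def screen_samples(ious, pos_thr=0.7, neg_thr=0.3):
        # ious: (num_anchors, num_gt) overlap matrix.
        max_iou = ious.max(axis=1)                       # best GT overlap per anchor
        labels = np.full(ious.shape[0], -1, np.int32)    # -1 = ignored
        labels[max_iou < neg_thr] = 0                    # negatives
        labels[max_iou >= pos_thr] = 1                   # positives
        labels[ious.argmax(axis=0)] = 1                  # best anchor per GT stays positive
        return labels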

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""FasterRcnn proposal generator."""
+"""CTPN proposal generator."""
 import numpy as np
 import mindspore.nn as nn
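Note: the proposal-generation code itself is not shown in this hunk. As a hedged reference, the usual RPN decoding step looks roughly like the numpy sketch below; the box-delta convention and the top-k cutoff are assumptions, and NMS would normally follow.

    import numpy as np

    def decode_proposals(anchors, deltas, scores, img_h, img_w, top_n=2000):
        w = anchors[:, 2] - anchors[:, 0] + 1.0
        h = anchors[:, 3] - anchors[:, 1] + 1.0
        cx = anchors[:, 0] + 0.5 * (w - 1.0)
        cy = anchors[:, 1] + 0.5 * (h - 1.0)
        # apply predicted deltas (dx, dy, dw, dh) to each anchor
        px = deltas[:, 0] * w + cx
        py = deltas[:, 1] * h + cy
        pw = np.exp(deltas[:, 2]) * w
        ph = np.exp(deltas[:, 3]) * h
        boxes = np.stack([px - 0.5 * (pw - 1.0), py - 0.5 * (ph - 1.0),
                          px + 0.5 * (pw - 1.0), py + 0.5 * (ph - 1.0)], axis=1)
        boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_w - 1)   # clip x to image
        boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_h - 1)   # clip y to image
        keep = scores.argsort()[::-1][:top_n]                # highest scores first
        return boxes[keep], scores[keep]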

@@ -49,25 +49,24 @@ class RpnRegClsBlock(nn.Cell):
         self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
         self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
         self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
-        self.shape1 = (config.num_step, config.rnn_batch_size, -1)
-        self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step)
+        self.shape1 = (-1, config.num_step, config.rnn_batch_size)
+        self.shape2 = (config.batch_size, -1, config.rnn_batch_size, config.num_step)
         self.transpose = P.Transpose()
-        self.print = P.Print()
         self.dropout = nn.Dropout(0.8)

     def construct(self, x):
         x = self.reshape(x, self.shape)
         x = self.lstm_fc(x)
         x1 = self.rpn_cls(x)
-        x1 = self.transpose(x1, (1, 0))
         x1 = self.reshape(x1, self.shape1)
-        x1 = self.transpose(x1, (2, 1, 0))
+        x1 = self.transpose(x1, (0, 2, 1))
         x1 = self.reshape(x1, self.shape2)
+        x1 = self.transpose(x1, (1, 0, 2, 3))
         x2 = self.rpn_reg(x)
-        x2 = self.transpose(x2, (1, 0))
         x2 = self.reshape(x2, self.shape1)
-        x2 = self.transpose(x2, (2, 1, 0))
+        x2 = self.transpose(x2, (0, 2, 1))
         x2 = self.reshape(x2, self.shape2)
+        x2 = self.transpose(x2, (1, 0, 2, 3))
         return x1, x2

 class RPN(nn.Cell):
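Note: this hunk drops the debug Print op and the leading 2-D transpose, and reorders the reshape/transpose chain around the new shape1/shape2. A quick numpy trace of the shapes produced by the new x1 path is given below; the concrete sizes (batch_size, rnn_batch_size as feature-map height, num_step as width, 14 anchors x 2 class outputs) are illustrative assumptions, not this repository's config.

    import numpy as np

    batch_size, H, W, C = 1, 38, 63, 14 * 2              # assumed toy sizes: H=rnn_batch_size, W=num_step
    x1 = np.zeros((batch_size * H * W, C), np.float16)   # rpn_cls output, one score row per position

    x1 = x1.reshape(-1, W, H)               # self.shape1 = (-1, num_step, rnn_batch_size)
    x1 = x1.transpose(0, 2, 1)              # -> (-1, rnn_batch_size, num_step)
    x1 = x1.reshape(batch_size, -1, H, W)   # self.shape2 = (batch_size, -1, rnn_batch_size, num_step)
    x1 = x1.transpose(1, 0, 2, 3)           # -> (channels, batch_size, rnn_batch_size, num_step)
    print(x1.shape)                         # (28, 1, 38, 63)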

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
-"""FasterRcnn dataset"""
+"""CTPN dataset"""
 from __future__ import division
 import os
 import numpy as np
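Note: only the docstring changes in this hunk. For readers new to the file, a purely illustrative MindSpore data pipeline of the kind such a module builds might look like the sketch below; the column names and sample shapes are assumptions, not this repository's dataset code.

    import numpy as np
    import mindspore.dataset as ds

    def toy_text_dataset(batch_size=2):
        # Illustrative generator yielding (image, gt_box, label) triples.
        def gen():
            for _ in range(8):
                yield (np.zeros((3, 576, 960), np.float32),
                       np.array([[0., 0., 15., 15.]], np.float32),
                       np.array([1], np.int32))
        dataset = ds.GeneratorDataset(gen, column_names=["image", "box", "label"])
        return dataset.batch(batch_size, drop_remainder=True)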

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""FasterRcnn training network wrapper."""
+"""CTPN training network wrapper."""
 import time
 import numpy as np
@@ -82,7 +82,7 @@ class LossCallBack(Callback):
         loss_file.close()

 class LossNet(nn.Cell):
-    """FasterRcnn loss method"""
+    """CTPN loss method"""
     def construct(self, x1, x2, x3):
         return x1
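Note: LossNet.construct(x1, x2, x3) forwards only the first of the three values the training network returns, so a single scalar drives backpropagation. Below is a hedged sketch of how such a selector is typically paired with a with-loss wrapper; the class name and the assumed (total_loss, cls_loss, reg_loss) ordering are illustrative, not taken from this commit.

    import mindspore.nn as nn

    class WithLossCell(nn.Cell):
        # Illustrative wrapper: the backbone is assumed to return a tuple such as
        # (total_loss, cls_loss, reg_loss); loss_net keeps only the first entry,
        # giving TrainOneStepCell a single scalar to differentiate.
        def __init__(self, backbone, loss_net):
            super(WithLossCell, self).__init__(auto_prefix=False)
            self.backbone = backbone
            self.loss_net = loss_net

        def construct(self, *inputs):
            losses = self.backbone(*inputs)
            return self.loss_net(*losses)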
