|
|
@ -49,25 +49,24 @@ class RpnRegClsBlock(nn.Cell):
|
|
|
|
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
|
|
|
|
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
|
|
|
|
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
|
|
|
|
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
|
|
|
|
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
|
|
|
|
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
|
|
|
|
self.shape1 = (config.num_step, config.rnn_batch_size, -1)
|
|
|
|
self.shape1 = (-1, config.num_step, config.rnn_batch_size)
|
|
|
|
self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step)
|
|
|
|
self.shape2 = (config.batch_size, -1, config.rnn_batch_size, config.num_step)
|
|
|
|
self.transpose = P.Transpose()
|
|
|
|
self.transpose = P.Transpose()
|
|
|
|
self.print = P.Print()
|
|
|
|
self.print = P.Print()
|
|
|
|
self.dropout = nn.Dropout(0.8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
def construct(self, x):
|
|
|
|
x = self.reshape(x, self.shape)
|
|
|
|
x = self.reshape(x, self.shape)
|
|
|
|
x = self.lstm_fc(x)
|
|
|
|
x = self.lstm_fc(x)
|
|
|
|
x1 = self.rpn_cls(x)
|
|
|
|
x1 = self.rpn_cls(x)
|
|
|
|
|
|
|
|
x1 = self.transpose(x1, (1, 0))
|
|
|
|
x1 = self.reshape(x1, self.shape1)
|
|
|
|
x1 = self.reshape(x1, self.shape1)
|
|
|
|
x1 = self.transpose(x1, (2, 1, 0))
|
|
|
|
x1 = self.transpose(x1, (0, 2, 1))
|
|
|
|
x1 = self.reshape(x1, self.shape2)
|
|
|
|
x1 = self.reshape(x1, self.shape2)
|
|
|
|
x1 = self.transpose(x1, (1, 0, 2, 3))
|
|
|
|
|
|
|
|
x2 = self.rpn_reg(x)
|
|
|
|
x2 = self.rpn_reg(x)
|
|
|
|
|
|
|
|
x2 = self.transpose(x2, (1, 0))
|
|
|
|
x2 = self.reshape(x2, self.shape1)
|
|
|
|
x2 = self.reshape(x2, self.shape1)
|
|
|
|
x2 = self.transpose(x2, (2, 1, 0))
|
|
|
|
x2 = self.transpose(x2, (0, 2, 1))
|
|
|
|
x2 = self.reshape(x2, self.shape2)
|
|
|
|
x2 = self.reshape(x2, self.shape2)
|
|
|
|
x2 = self.transpose(x2, (1, 0, 2, 3))
|
|
|
|
|
|
|
|
return x1, x2
|
|
|
|
return x1, x2
|
|
|
|
|
|
|
|
|
|
|
|
class RPN(nn.Cell):
|
|
|
|
class RPN(nn.Cell):
|
|
|
|