|
|
|
@ -356,19 +356,16 @@ class SsdMobilenetV1Fpn(nn.Cell):
|
|
|
|
|
Examples:backbone
|
|
|
|
|
SsdMobilenetV1Fpn(config, True).
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, config, is_training=True):
|
|
|
|
|
def __init__(self, config):
|
|
|
|
|
super(SsdMobilenetV1Fpn, self).__init__()
|
|
|
|
|
self.multi_box = WeightSharedMultiBox(config)
|
|
|
|
|
self.is_training = is_training
|
|
|
|
|
if not is_training:
|
|
|
|
|
self.activation = P.Sigmoid()
|
|
|
|
|
|
|
|
|
|
self.activation = P.Sigmoid()
|
|
|
|
|
self.feature_extractor = mobilenet_v1_fpn(config)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
features = self.feature_extractor(x)
|
|
|
|
|
pred_loc, pred_label = self.multi_box(features)
|
|
|
|
|
if not self.is_training:
|
|
|
|
|
if not self.training:
|
|
|
|
|
pred_label = self.activation(pred_label)
|
|
|
|
|
pred_loc = F.cast(pred_loc, mstype.float32)
|
|
|
|
|
pred_label = F.cast(pred_label, mstype.float32)
|
|
|
|
|