diff --git a/mindspore/lite/examples/export_models/models/effnet.py b/mindspore/lite/examples/export_models/models/effnet.py index 08ed460a36..cc550ea62d 100755 --- a/mindspore/lite/examples/export_models/models/effnet.py +++ b/mindspore/lite/examples/export_models/models/effnet.py @@ -25,24 +25,20 @@ def weight_variable(): return TruncatedNormal(0.02) -def _make_divisible(v, divisor, min_value=None): +def _make_value_divisible(value, factor, min_value=None): """ - This function is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by 8 - It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - :param v: - :param divisor: - iparam min_value: - :return: + :param v: value to process + :param factor: divisor + :param min_value: new value always greater than the min_value + :return: new value """ if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v + min_value = factor + new_value = max(int(value + factor / 2) // factor * factor, min_value) + if new_value < value * 0.9: + new_value += factor + return new_value class Swish(nn.Cell): def __init__(self): @@ -69,7 +65,7 @@ class SELayer(nn.Cell): """SELayer""" def __init__(self, channel, reduction=4): super().__init__() - reduced_chs = _make_divisible(channel/reduction, 1) + reduced_chs = _make_value_divisible(channel/reduction, 1) self.avg_pool = AdaptiveAvgPool(output_size=(1, 1)) weight = weight_variable() self.conv_reduce = nn.Conv2d(in_channels=channel, out_channels=reduced_chs, kernel_size=1, has_bias=True, @@ -152,7 +148,7 @@ class InvertedResidual(nn.Cell): def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio): super().__init__() assert stride in [1, 2] - mid_chs: int = _make_divisible(in_chs * expansion, 1) + mid_chs: int = _make_value_divisible(in_chs * expansion, 1) self.has_residual = (in_chs == out_chs and stride == 1) self.drop_connect_rate = 0