|
|
|
@ -25,24 +25,20 @@ def weight_variable():
|
|
|
|
|
return TruncatedNormal(0.02)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _make_value_divisible(value, factor, min_value=None):
    """
    Round ``value`` to the nearest multiple of ``factor``.

    Used to keep layer channel numbers divisible by a fixed factor (e.g. 8).
    The rounding rule is taken from the original TensorFlow repo:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    Note that when ``min_value`` is larger than the rounded result it is
    returned as-is, even if it is not itself a multiple of ``factor``.

    :param value: value to process
    :param factor: divisor the result is rounded to a multiple of
    :param min_value: lower bound for the result; defaults to ``factor``
    :return: new value
    """
    if min_value is None:
        min_value = factor
    # Shift by half a factor before flooring so the result is rounded to the
    # nearest multiple rather than always down, then apply the lower bound.
    new_value = max(int(value + factor / 2) // factor * factor, min_value)
    # Make sure that rounding down does not go down by more than 10%.
    if new_value < value * 0.9:
        new_value += factor
    return new_value


def _make_divisible(v, divisor, min_value=None):
    """
    Backward-compatible name for :func:`_make_value_divisible`.

    :param v: value to process
    :param divisor: divisor the result is rounded to a multiple of
    :param min_value: lower bound for the result; defaults to ``divisor``
    :return: new value
    """
    return _make_value_divisible(v, divisor, min_value)
|
|
|
|
|
|
|
|
|
|
class Swish(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
@ -69,7 +65,7 @@ class SELayer(nn.Cell):
|
|
|
|
|
"""SELayer"""
|
|
|
|
|
def __init__(self, channel, reduction=4):
|
|
|
|
|
super().__init__()
|
|
|
|
|
reduced_chs = _make_divisible(channel/reduction, 1)
|
|
|
|
|
reduced_chs = _make_value_divisible(channel/reduction, 1)
|
|
|
|
|
self.avg_pool = AdaptiveAvgPool(output_size=(1, 1))
|
|
|
|
|
weight = weight_variable()
|
|
|
|
|
self.conv_reduce = nn.Conv2d(in_channels=channel, out_channels=reduced_chs, kernel_size=1, has_bias=True,
|
|
|
|
@ -152,7 +148,7 @@ class InvertedResidual(nn.Cell):
|
|
|
|
|
def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio):
|
|
|
|
|
super().__init__()
|
|
|
|
|
assert stride in [1, 2]
|
|
|
|
|
mid_chs: int = _make_divisible(in_chs * expansion, 1)
|
|
|
|
|
mid_chs: int = _make_value_divisible(in_chs * expansion, 1)
|
|
|
|
|
self.has_residual = (in_chs == out_chs and stride == 1)
|
|
|
|
|
self.drop_connect_rate = 0
|
|
|
|
|
|
|
|
|
|