Add hard swish op (new op) (#19001)

* add hard_swish activation op (new op)
test=develop

* remove redundancy files

* modify document content of HardSwish OP

* add API test in test_layers.py

* add dynamic_graph for test_hard_swish
padding_in_crf
huangjun12 6 years ago committed by SunGaofeng
parent bce72c7fea
commit 20f18930ae

@ -270,6 +270,7 @@ paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddi
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', '6a5152a7015c62cb8278fc24cb456459'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cccb6eb5410c822e5307c947aca2c899'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6'))

@ -573,6 +573,32 @@ $$out = \\frac{x}{1 + e^{- \beta \ x}}$$
}
};
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of HardSwish operator");
AddOutput("Out", "Output of HardSwish operator");
AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
.SetDefault(6.0f);
AddAttr<float>("scale", "The scale parameter of HardSwish operator")
.SetDefault(6.0f);
AddAttr<float>("offset", "The offset parameter of HardSwish operator")
.SetDefault(3.0f);
AddComment(R"DOC(
HardSwish Activation Operator.
The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).
$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$
The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.
)DOC");
}
};
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);

@ -919,6 +919,51 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// HardSwish = min(max(0, x+3), 6) * x / 6
template <typename T>
struct HardSwishFunctor : public BaseActivationFunctor<T> {
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = (x + static_cast<T>(offset))
.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(threshold)) *
x / static_cast<T>(scale);
}
};
template <typename T>
struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto tmp = ((x + static_cast<T>(offset)) < static_cast<T>(threshold))
.template cast<T>();
dx.device(d) =
dout *
(((x + static_cast<T>(offset)) > static_cast<T>(0)).template cast<T>() *
(static_cast<T>(2) * x + static_cast<T>(offset)) /
static_cast<T>(scale) * tmp +
static_cast<T>(1) * (static_cast<T>(1) - tmp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// softplus(x) = log(1 + exp(x))
// When x is a very large positive number, exp(x) may explode to inf,
// Using trick below for numerical stability
@ -1580,4 +1625,5 @@ class SqrtDoubleGradKernel
HardSigmoidGradFunctor); \
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
__macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \
ThresholdedReluGradFunctor);
ThresholdedReluGradFunctor); \
__macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor);

@ -213,6 +213,7 @@ __all__ = [
'deformable_roi_pooling',
'var_conv_2d',
'shard_index',
'hard_swish',
]
kIgnoreIndex = -100
@ -13100,3 +13101,38 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
},
stop_gradient=True)
return out
@templatedoc()
def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
"""
${comment}
Args:
x(Varaible): Input of HardSwish operator.
threshold(float): The threshold parameter of HardSwish operator. Default:threshold=6.0
scale(float): The scale parameter of HardSwish operator. Default:scale=6.0
offset(float): The offset parameter of HardSwish operator. Default:offset=3.0
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The output tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name="x", shape=[3,10,32,32], dtype="float32")
y = fluid.layers.hard_swish(x)
"""
helper = LayerHelper('hard_swish', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='hard_swish',
inputs={'X': x},
outputs={'Out': out},
attrs={'threshold': threshold,
'scale': scale,
'offset': offset})
return out

@ -450,6 +450,30 @@ class TestRelu6(TestActivation):
self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestHardSwish(TestActivation):
def setUp(self):
self.op_type = 'hard_swish'
self.init_dtype()
x = np.random.uniform(-6, 6, [4, 4]).astype(self.dtype)
threshold = 6.0
scale = 6.0
offset = 3.0
#the same with TestAbs
x[np.abs(x + offset) < 0.005] = 0.02
x[np.abs(x - threshold + offset) < 0.005] = threshold - offset + 0.02
out = x * np.minimum(np.maximum(x + offset, 0), threshold) / scale
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestSoftRelu(TestActivation):
def setUp(self):
self.op_type = "soft_relu"
@ -773,6 +797,7 @@ create_test_act_fp16_class(TestSoftsign)
create_test_act_fp16_class(TestThresholdedRelu)
create_test_act_fp16_class(TestHardSigmoid)
create_test_act_fp16_class(TestSwish)
create_test_act_fp16_class(TestHardSwish)
if __name__ == "__main__":
unittest.main()

@ -903,6 +903,20 @@ class TestLayer(LayerTest):
with self.assertRaises(TypeError):
layers.eye(num_rows=3, batch_shape=[-1])
def test_hard_swish(self):
with self.static_graph():
t = layers.data(name='t', shape=[3, 3], dtype='float32')
ret = layers.hard_swish(t)
static_ret = self.get_static_graph_result(
feed={'t': np.ones(
[3, 3], dtype='float32')}, fetch_list=[ret])[0]
with self.dynamic_graph():
t = np.ones([3, 3], dtype='float32')
dy_ret = layers.hard_swish(base.to_variable(t))
self.assertTrue(np.allclose(static_ret, dy_ret.numpy()))
class TestBook(LayerTest):
def test_all_layers(self):

Loading…
Cancel
Save