Fix the param of swish (#27824)

my_2.0rc
hong19860320 4 years ago committed by GitHub
parent 070ac9590c
commit f3e2580cf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2175,7 +2175,7 @@ class TestSwish(TestActivation):
x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
out = ref_swish(x)
self.inputs = {'X': x}
self.attrs = {'slope': 1.0}
self.attrs = {'beta': 1.0}
self.outputs = {'Out': out}
def test_check_grad(self):

@ -1183,7 +1183,7 @@ def swish(x, name=None):
"""
if in_dygraph_mode():
return core.ops.swish(x, 'slop', 1.0)
return core.ops.swish(x, 'beta', 1.0)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'swish')
helper = LayerHelper('swish', **locals())
@ -1192,7 +1192,7 @@ def swish(x, name=None):
type='swish',
inputs={'X': x},
outputs={'Out': out},
attrs={'slope': 1.0})
attrs={'beta': 1.0})
return out

Loading…
Cancel
Save