fixed LeakyReLU

pull/1961/head
jiangjinsheng 5 years ago
parent 09318086aa
commit 6c92282e5e

@ -249,11 +249,11 @@ class LeakyReLU(Cell):
self.alpha = alpha
def construct(self, x):
alpha = P.Cast()(F.scalar_to_array(self.alpha), P.DType()(x))
if alpha <= 1:
out = P.Maximum()(alpha * x, x)
alpha_array = P.Cast()(F.scalar_to_array(self.alpha), P.DType()(x))
if self.alpha <= 1:
out = P.Maximum()(alpha_array * x, x)
else:
out = P.Minimum()(alpha * x, x)
out = P.Minimum()(alpha_array * x, x)
return out

Loading…
Cancel
Save