fix leaky_relu op when alpha is zero, test=develop (#19833)

expand_as_op_1
Zeng Jinle 6 years ago committed by GitHub
parent 9cbc1eff2d
commit cabb9501bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1073,8 +1073,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 =
static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>();
auto temp2 = (out >= static_cast<T>(0)).template cast<T>();
static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
auto temp2 = (out > static_cast<T>(0)).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
@ -1418,10 +1418,10 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) =
ddx *
((out >= static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>())
ddout.device(*d) = ddx *
((out > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) *
(out <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}

@ -22,5 +22,10 @@ TEST(leaky_relu_grad_grad, test_cpu) {
TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.02));
}
TEST(leaky_relu_grad_grad, test_cpu_zero_alpha) {
ASSERT_TRUE(
TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.0));
}
} // namespace operators
} // namespace paddle

@ -22,5 +22,10 @@ TEST(leaky_relu_grad_grad, test_gpu) {
TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.15));
}
TEST(leaky_relu_grad_grad, test_gpu_zero_alpha) {
ASSERT_TRUE(
TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.0));
}
} // namespace operators
} // namespace paddle

@ -46,7 +46,7 @@ struct LeakyReluGradGradEachElementFunctor {
: ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {}
HOSTDEVICE void operator()(int idx) {
if (out_[idx] >= 0) {
if (out_[idx] > 0) {
ddout_[idx] = ddx_[idx];
} else {
ddout_[idx] = ddx_[idx] * alpha_;

Loading…
Cancel
Save