fix leaky_relu op when alpha is zero, test=develop (#19833)

Zeng Jinle authored 6 years ago, committed by GitHub
parent 9cbc1eff2d
commit cabb9501bd

@@ -1073,8 +1073,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp1 =
-        static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>();
-    auto temp2 = (out >= static_cast<T>(0)).template cast<T>();
+        static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
+    auto temp2 = (out > static_cast<T>(0)).template cast<T>();
     dx.device(d) = dout * (temp1 + temp2).template cast<T>();
   }
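
Why the comparison moves from `>=`/`<` to `>`/`<=`: the backward pass only sees `out`, and when alpha == 0 every negative input yields out == alpha * x == 0, so `out >= 0` wrongly routed those elements through the slope-1 branch. A minimal standalone sketch of the two conventions (plain C++, no Paddle/Eigen types; `leaky_relu_grad_from_out` is a hypothetical helper for illustration):

```cpp
#include <cassert>

// Hypothetical scalar version of LeakyReluGradFunctor: recover dx from the
// forward output alone, exactly as the Paddle functor does. `strict` toggles
// between the fixed comparison (out > 0) and the old one (out >= 0).
float leaky_relu_grad_from_out(float dout, float out, float alpha,
                               bool strict) {
  bool positive_branch = strict ? (out > 0.0f) : (out >= 0.0f);
  return positive_branch ? dout : dout * alpha;
}

int main() {
  const float alpha = 0.0f;     // the problematic value
  const float x = -2.0f;        // a negative input
  const float out = alpha * x;  // forward pass: out == 0 even though x < 0
  // Old convention: out >= 0 is true, so the gradient passes through as 1.
  assert(leaky_relu_grad_from_out(1.0f, out, alpha, /*strict=*/false) == 1.0f);
  // Fixed convention: out > 0 is false, so dx = dout * alpha = 0.
  assert(leaky_relu_grad_from_out(1.0f, out, alpha, /*strict=*/true) == 0.0f);
  return 0;
}
```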
@@ -1418,10 +1418,10 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
     auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
     auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
     auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
-    ddout.device(*d) =
-        ddx *
-        ((out >= static_cast<T>(0)).template cast<T>() +
-         static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>())
-            .template cast<T>();
+    ddout.device(*d) = ddx *
+                       ((out > static_cast<T>(0)).template cast<T>() +
+                        static_cast<T>(alpha) *
+                            (out <= static_cast<T>(0)).template cast<T>())
+                           .template cast<T>();
   }
 }
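
Note: the second-order functor applies the same mask to ddX that the first-order functor applies to dOut: ddout = ddx where out > 0, and ddout = ddx * alpha where out <= 0. Keeping the comparison strict here is what makes the alpha == 0 case consistent, since every negative input produces out == 0 in the forward pass. The scalar form of this rule appears in LeakyReluGradGradEachElementFunctor below.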

@@ -22,5 +22,10 @@ TEST(leaky_relu_grad_grad, test_cpu) {
       TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.02));
 }
 
+TEST(leaky_relu_grad_grad, test_cpu_zero_alpha) {
+  ASSERT_TRUE(
+      TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.0));
+}
+
 }  // namespace operators
 }  // namespace paddle

@@ -22,5 +22,10 @@ TEST(leaky_relu_grad_grad, test_gpu) {
       TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.15));
 }
 
+TEST(leaky_relu_grad_grad, test_gpu_zero_alpha) {
+  ASSERT_TRUE(
+      TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.0));
+}
+
 }  // namespace operators
 }  // namespace paddle
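
Both new tests pin alpha = 0.0, the exact boundary this commit targets; under the old `out >= 0` convention they would see ddout = ddx instead of ddout = 0 for negative inputs. For intuition, here is a self-contained finite-difference check of the first-order gradient at alpha = 0 (plain C++; `leaky_relu` and `grad_from_out` are illustrative helpers, not part of the Paddle test harness):

```cpp
#include <cassert>
#include <cmath>

// Forward pass of LeakyRelu.
float leaky_relu(float x, float alpha) { return x > 0.0f ? x : alpha * x; }

// Analytic gradient recovered from the forward output, using the fixed
// strict comparison: out > 0 selects slope 1, otherwise slope alpha.
float grad_from_out(float out, float alpha) {
  return out > 0.0f ? 1.0f : alpha;
}

int main() {
  const float alpha = 0.0f;  // the boundary case the new tests cover
  const float eps = 1e-3f;
  for (float x : {-2.0f, -0.5f, 0.5f, 2.0f}) {
    float out = leaky_relu(x, alpha);
    // Central finite difference of the forward pass.
    float numeric =
        (leaky_relu(x + eps, alpha) - leaky_relu(x - eps, alpha)) / (2 * eps);
    assert(std::fabs(grad_from_out(out, alpha) - numeric) < 1e-3f);
  }
  return 0;
}
```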

@@ -46,7 +46,7 @@ struct LeakyReluGradGradEachElementFunctor {
       : ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {}
 
   HOSTDEVICE void operator()(int idx) {
-    if (out_[idx] >= 0) {
+    if (out_[idx] > 0) {
       ddout_[idx] = ddx_[idx];
     } else {
       ddout_[idx] = ddx_[idx] * alpha_;
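
This elementwise functor is the scalar counterpart of the Eigen expressions above: with the strict comparison, an element whose stored output is exactly zero (which, for alpha == 0, is every element that was negative going into the forward pass) now takes the alpha branch, so all three code paths agree at the boundary.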
