|
|
|
@ -144,7 +144,7 @@ struct AbsFunctor<T, NoComplex<T, Real<T>>> {
|
|
|
|
|
: input_(input), output_(output), numel_(numel) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(int64_t idx) const {
|
|
|
|
|
output_[idx] = abs(input_[idx]);
|
|
|
|
|
output_[idx] = std::abs(input_[idx]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T* input_;
|
|
|
|
@ -162,7 +162,7 @@ struct AbsGradFunctor {
|
|
|
|
|
if (x_[idx] == T(0)) {
|
|
|
|
|
output_[idx] = T(0);
|
|
|
|
|
} else {
|
|
|
|
|
output_[idx] = T(dout_[idx]) * (x_[idx] / T(abs(x_[idx])));
|
|
|
|
|
output_[idx] = T(dout_[idx]) * (x_[idx] / T(std::abs(x_[idx])));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -172,6 +172,48 @@ struct AbsGradFunctor {
|
|
|
|
|
int64_t numel_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct AbsGradFunctor<paddle::platform::complex64> {
|
|
|
|
|
AbsGradFunctor(const float* dout, const paddle::platform::complex64* x,
|
|
|
|
|
paddle::platform::complex64* output, int64_t numel)
|
|
|
|
|
: dout_(dout), x_(x), output_(output), numel_(numel) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(int64_t idx) const {
|
|
|
|
|
if (x_[idx] == paddle::platform::complex64(0)) {
|
|
|
|
|
output_[idx] = paddle::platform::complex64(0);
|
|
|
|
|
} else {
|
|
|
|
|
output_[idx] = paddle::platform::complex64(dout_[idx]) *
|
|
|
|
|
(x_[idx] / paddle::platform::complex64(abs(x_[idx])));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const float* dout_;
|
|
|
|
|
const paddle::platform::complex64* x_;
|
|
|
|
|
paddle::platform::complex64* output_;
|
|
|
|
|
int64_t numel_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct AbsGradFunctor<paddle::platform::complex128> {
|
|
|
|
|
AbsGradFunctor(const double* dout, const paddle::platform::complex128* x,
|
|
|
|
|
paddle::platform::complex128* output, int64_t numel)
|
|
|
|
|
: dout_(dout), x_(x), output_(output), numel_(numel) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(int64_t idx) const {
|
|
|
|
|
if (x_[idx] == paddle::platform::complex128(0)) {
|
|
|
|
|
output_[idx] = paddle::platform::complex128(0);
|
|
|
|
|
} else {
|
|
|
|
|
output_[idx] = paddle::platform::complex128(dout_[idx]) *
|
|
|
|
|
(x_[idx] / paddle::platform::complex128(abs(x_[idx])));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const double* dout_;
|
|
|
|
|
const paddle::platform::complex128* x_;
|
|
|
|
|
paddle::platform::complex128* output_;
|
|
|
|
|
int64_t numel_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct AbsGradGradFunctor {
|
|
|
|
|
AbsGradGradFunctor(const T* ddx, const T* x, T* output, int64_t numel)
|
|
|
|
@ -181,7 +223,7 @@ struct AbsGradGradFunctor {
|
|
|
|
|
if (x_[idx] == T(0)) {
|
|
|
|
|
output_[idx] = T(0);
|
|
|
|
|
} else {
|
|
|
|
|
output_[idx] = T(ddx_[idx]) * x_[idx] / T(abs(x_[idx]));
|
|
|
|
|
output_[idx] = T(ddx_[idx]) * x_[idx] / T(std::abs(x_[idx]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -191,6 +233,49 @@ struct AbsGradGradFunctor {
|
|
|
|
|
int64_t numel_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct AbsGradGradFunctor<paddle::platform::complex128> {
|
|
|
|
|
AbsGradGradFunctor(const paddle::platform::complex128* ddx,
|
|
|
|
|
const paddle::platform::complex128* x,
|
|
|
|
|
paddle::platform::complex128* output, int64_t numel)
|
|
|
|
|
: ddx_(ddx), x_(x), output_(output), numel_(numel) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(int64_t idx) const {
|
|
|
|
|
if (x_[idx] == paddle::platform::complex128(0)) {
|
|
|
|
|
output_[idx] = paddle::platform::complex128(0);
|
|
|
|
|
} else {
|
|
|
|
|
output_[idx] = paddle::platform::complex128(ddx_[idx]) * x_[idx] /
|
|
|
|
|
paddle::platform::complex128(abs(x_[idx]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const paddle::platform::complex128* ddx_;
|
|
|
|
|
const paddle::platform::complex128* x_;
|
|
|
|
|
paddle::platform::complex128* output_;
|
|
|
|
|
int64_t numel_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct AbsGradGradFunctor<paddle::platform::complex64> {
|
|
|
|
|
AbsGradGradFunctor(const paddle::platform::complex64* ddx,
|
|
|
|
|
const paddle::platform::complex64* x,
|
|
|
|
|
paddle::platform::complex64* output, int64_t numel)
|
|
|
|
|
: ddx_(ddx), x_(x), output_(output), numel_(numel) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(int64_t idx) const {
|
|
|
|
|
if (x_[idx] == paddle::platform::complex64(0)) {
|
|
|
|
|
output_[idx] = paddle::platform::complex64(0);
|
|
|
|
|
} else {
|
|
|
|
|
output_[idx] = paddle::platform::complex64(ddx_[idx]) * x_[idx] /
|
|
|
|
|
paddle::platform::complex64(abs(x_[idx]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const paddle::platform::complex64* ddx_;
|
|
|
|
|
const paddle::platform::complex64* x_;
|
|
|
|
|
paddle::platform::complex64* output_;
|
|
|
|
|
int64_t numel_;
|
|
|
|
|
};
|
|
|
|
|
template <typename T, typename Enable = void>
|
|
|
|
|
struct RealToComplexFunctor;
|
|
|
|
|
|
|
|
|
|