|
|
@ -73,6 +73,27 @@ struct DivGradDX {
|
|
|
|
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
|
|
|
|
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
struct DivGradDX<paddle::platform::complex64> {
|
|
|
|
|
|
|
|
HOSTDEVICE paddle::platform::complex64 operator()(
|
|
|
|
|
|
|
|
paddle::platform::complex64 x, paddle::platform::complex64 y,
|
|
|
|
|
|
|
|
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
|
|
|
|
|
|
|
|
paddle::platform::complex64 y_conj(y.real, -y.imag);
|
|
|
|
|
|
|
|
return dout / y_conj;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
struct DivGradDX<paddle::platform::complex128> {
|
|
|
|
|
|
|
|
HOSTDEVICE paddle::platform::complex128 operator()(
|
|
|
|
|
|
|
|
paddle::platform::complex128 x, paddle::platform::complex128 y,
|
|
|
|
|
|
|
|
paddle::platform::complex128 out,
|
|
|
|
|
|
|
|
paddle::platform::complex128 dout) const {
|
|
|
|
|
|
|
|
paddle::platform::complex128 y_conj(y.real, -y.imag);
|
|
|
|
|
|
|
|
return dout / y_conj;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
struct DivGradDY {
|
|
|
|
struct DivGradDY {
|
|
|
|
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
|
|
|
|
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
|
|
|
@ -80,6 +101,28 @@ struct DivGradDY {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
struct DivGradDY<paddle::platform::complex64> {
|
|
|
|
|
|
|
|
HOSTDEVICE paddle::platform::complex64 operator()(
|
|
|
|
|
|
|
|
paddle::platform::complex64 x, paddle::platform::complex64 y,
|
|
|
|
|
|
|
|
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
|
|
|
|
|
|
|
|
paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag);
|
|
|
|
|
|
|
|
return -dout * out_div_y_conj;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
struct DivGradDY<paddle::platform::complex128> {
|
|
|
|
|
|
|
|
HOSTDEVICE paddle::platform::complex128 operator()(
|
|
|
|
|
|
|
|
paddle::platform::complex128 x, paddle::platform::complex128 y,
|
|
|
|
|
|
|
|
paddle::platform::complex128 out,
|
|
|
|
|
|
|
|
paddle::platform::complex128 dout) const {
|
|
|
|
|
|
|
|
paddle::platform::complex128 out_div_y_conj((out / y).real,
|
|
|
|
|
|
|
|
-(out / y).imag);
|
|
|
|
|
|
|
|
return -dout * out_div_y_conj;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
struct DivDoubleDY {
|
|
|
|
struct DivDoubleDY {
|
|
|
|
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
|
|
|
|
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
|
|
|
|