|
|
|
@ -26,6 +26,9 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using complex64 = paddle::platform::complex64;
|
|
|
|
|
using complex128 = paddle::platform::complex128;
|
|
|
|
|
|
|
|
|
|
// Process an element in the output, used with a parallel-for
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct KronElemFunctor {
|
|
|
|
@ -172,6 +175,128 @@ struct KronGradElemFunctor {
|
|
|
|
|
const int ndims_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct KronGradElemFunctor<complex64> {
|
|
|
|
|
KronGradElemFunctor(const complex64* dout, const complex64* A,
|
|
|
|
|
const complex64* B, complex64* dout_a, complex64* dout_b,
|
|
|
|
|
const int64_t* stride_dout, const int64_t* stride_a,
|
|
|
|
|
const int64_t* stride_b, const int64_t* shape_b,
|
|
|
|
|
const int64_t numel_a, const int64_t numel_b,
|
|
|
|
|
const int ndims)
|
|
|
|
|
: dout_(dout),
|
|
|
|
|
A_(A),
|
|
|
|
|
B_(B),
|
|
|
|
|
dout_a_(dout_a),
|
|
|
|
|
dout_b_(dout_b),
|
|
|
|
|
stride_dout_(stride_dout),
|
|
|
|
|
stride_a_(stride_a),
|
|
|
|
|
stride_b_(stride_b),
|
|
|
|
|
shape_b_(shape_b),
|
|
|
|
|
numel_a_(numel_a),
|
|
|
|
|
numel_b_(numel_b),
|
|
|
|
|
ndims_(ndims) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(int64_t idx) {
|
|
|
|
|
int64_t index = idx;
|
|
|
|
|
int64_t index_a = 0;
|
|
|
|
|
int64_t index_b = 0;
|
|
|
|
|
for (int i = 0; i < ndims_; i++) {
|
|
|
|
|
auto pos_i = index / stride_dout_[i];
|
|
|
|
|
index = index % stride_dout_[i];
|
|
|
|
|
auto pos_ai = pos_i / shape_b_[i];
|
|
|
|
|
auto pos_bi = pos_i % shape_b_[i];
|
|
|
|
|
index_a += stride_a_[i] * pos_ai;
|
|
|
|
|
index_b += stride_b_[i] * pos_bi;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dout_a_) {
|
|
|
|
|
size_t index_out_a = index_a * numel_b_ + index_b;
|
|
|
|
|
dout_a_[index_out_a] =
|
|
|
|
|
dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag);
|
|
|
|
|
}
|
|
|
|
|
if (dout_b_) {
|
|
|
|
|
size_t index_out_b = index_b * numel_a_ + index_a;
|
|
|
|
|
dout_b_[index_out_b] =
|
|
|
|
|
dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const complex64* dout_;
|
|
|
|
|
const complex64* A_;
|
|
|
|
|
const complex64* B_;
|
|
|
|
|
complex64* dout_a_;
|
|
|
|
|
complex64* dout_b_;
|
|
|
|
|
const int64_t* stride_dout_;
|
|
|
|
|
const int64_t* stride_a_;
|
|
|
|
|
const int64_t* stride_b_;
|
|
|
|
|
const int64_t* shape_b_;
|
|
|
|
|
const int64_t numel_a_;
|
|
|
|
|
const int64_t numel_b_;
|
|
|
|
|
const int ndims_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct KronGradElemFunctor<complex128> {
|
|
|
|
|
KronGradElemFunctor(const complex128* dout, const complex128* A,
|
|
|
|
|
const complex128* B, complex128* dout_a,
|
|
|
|
|
complex128* dout_b, const int64_t* stride_dout,
|
|
|
|
|
const int64_t* stride_a, const int64_t* stride_b,
|
|
|
|
|
const int64_t* shape_b, const int64_t numel_a,
|
|
|
|
|
const int64_t numel_b, const int ndims)
|
|
|
|
|
: dout_(dout),
|
|
|
|
|
A_(A),
|
|
|
|
|
B_(B),
|
|
|
|
|
dout_a_(dout_a),
|
|
|
|
|
dout_b_(dout_b),
|
|
|
|
|
stride_dout_(stride_dout),
|
|
|
|
|
stride_a_(stride_a),
|
|
|
|
|
stride_b_(stride_b),
|
|
|
|
|
shape_b_(shape_b),
|
|
|
|
|
numel_a_(numel_a),
|
|
|
|
|
numel_b_(numel_b),
|
|
|
|
|
ndims_(ndims) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(int64_t idx) {
|
|
|
|
|
int64_t index = idx;
|
|
|
|
|
int64_t index_a = 0;
|
|
|
|
|
int64_t index_b = 0;
|
|
|
|
|
for (int i = 0; i < ndims_; i++) {
|
|
|
|
|
auto pos_i = index / stride_dout_[i];
|
|
|
|
|
index = index % stride_dout_[i];
|
|
|
|
|
auto pos_ai = pos_i / shape_b_[i];
|
|
|
|
|
auto pos_bi = pos_i % shape_b_[i];
|
|
|
|
|
index_a += stride_a_[i] * pos_ai;
|
|
|
|
|
index_b += stride_b_[i] * pos_bi;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dout_a_) {
|
|
|
|
|
size_t index_out_a = index_a * numel_b_ + index_b;
|
|
|
|
|
dout_a_[index_out_a] =
|
|
|
|
|
dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag);
|
|
|
|
|
}
|
|
|
|
|
if (dout_b_) {
|
|
|
|
|
size_t index_out_b = index_b * numel_a_ + index_a;
|
|
|
|
|
dout_b_[index_out_b] =
|
|
|
|
|
dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const complex128* dout_;
|
|
|
|
|
const complex128* A_;
|
|
|
|
|
const complex128* B_;
|
|
|
|
|
complex128* dout_a_;
|
|
|
|
|
complex128* dout_b_;
|
|
|
|
|
const int64_t* stride_dout_;
|
|
|
|
|
const int64_t* stride_a_;
|
|
|
|
|
const int64_t* stride_b_;
|
|
|
|
|
const int64_t* shape_b_;
|
|
|
|
|
const int64_t numel_a_;
|
|
|
|
|
const int64_t numel_b_;
|
|
|
|
|
const int ndims_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct IdentityFunctor {
|
|
|
|
|
HOSTDEVICE explicit inline IdentityFunctor() {}
|
|
|
|
|