From 666e6651320b7eb12fe69d048e293fe0448d6387 Mon Sep 17 00:00:00 2001
From: chentianyu03
Date: Tue, 5 Jan 2021 17:40:11 +0800
Subject: [PATCH] change the kron gradient when complex types (#29995)

---
 paddle/fluid/operators/kron_op.h              | 125 ++++++++++++++++++
 .../fluid/tests/unittests/test_kron_op.py     |  85 ++++++++++++
 2 files changed, 210 insertions(+)

diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h
index 62762f3f04..2af3716ae4 100644
--- a/paddle/fluid/operators/kron_op.h
+++ b/paddle/fluid/operators/kron_op.h
@@ -26,6 +26,9 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using complex64 = paddle::platform::complex64;
+using complex128 = paddle::platform::complex128;
+
 // Process an element in the output, used with a parallel-for
 template <typename T>
 struct KronElemFunctor {
@@ -172,6 +175,128 @@ struct KronGradElemFunctor {
   const int ndims_;
 };
 
+template <>
+struct KronGradElemFunctor<complex64> {
+  KronGradElemFunctor(const complex64* dout, const complex64* A,
+                      const complex64* B, complex64* dout_a, complex64* dout_b,
+                      const int64_t* stride_dout, const int64_t* stride_a,
+                      const int64_t* stride_b, const int64_t* shape_b,
+                      const int64_t numel_a, const int64_t numel_b,
+                      const int ndims)
+      : dout_(dout),
+        A_(A),
+        B_(B),
+        dout_a_(dout_a),
+        dout_b_(dout_b),
+        stride_dout_(stride_dout),
+        stride_a_(stride_a),
+        stride_b_(stride_b),
+        shape_b_(shape_b),
+        numel_a_(numel_a),
+        numel_b_(numel_b),
+        ndims_(ndims) {}
+
+  HOSTDEVICE void operator()(int64_t idx) {
+    int64_t index = idx;
+    int64_t index_a = 0;
+    int64_t index_b = 0;
+    for (int i = 0; i < ndims_; i++) {
+      auto pos_i = index / stride_dout_[i];
+      index = index % stride_dout_[i];
+      auto pos_ai = pos_i / shape_b_[i];
+      auto pos_bi = pos_i % shape_b_[i];
+      index_a += stride_a_[i] * pos_ai;
+      index_b += stride_b_[i] * pos_bi;
+    }
+
+    if (dout_a_) {
+      size_t index_out_a = index_a * numel_b_ + index_b;
+      dout_a_[index_out_a] =
+          dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag);
+    }
+    if (dout_b_) {
+      size_t index_out_b = index_b * numel_a_ + index_a;
+      dout_b_[index_out_b] =
+          dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag);
+    }
+  }
+
+ private:
+  const complex64* dout_;
+  const complex64* A_;
+  const complex64* B_;
+  complex64* dout_a_;
+  complex64* dout_b_;
+  const int64_t* stride_dout_;
+  const int64_t* stride_a_;
+  const int64_t* stride_b_;
+  const int64_t* shape_b_;
+  const int64_t numel_a_;
+  const int64_t numel_b_;
+  const int ndims_;
+};
+
+template <>
+struct KronGradElemFunctor<complex128> {
+  KronGradElemFunctor(const complex128* dout, const complex128* A,
+                      const complex128* B, complex128* dout_a,
+                      complex128* dout_b, const int64_t* stride_dout,
+                      const int64_t* stride_a, const int64_t* stride_b,
+                      const int64_t* shape_b, const int64_t numel_a,
+                      const int64_t numel_b, const int ndims)
+      : dout_(dout),
+        A_(A),
+        B_(B),
+        dout_a_(dout_a),
+        dout_b_(dout_b),
+        stride_dout_(stride_dout),
+        stride_a_(stride_a),
+        stride_b_(stride_b),
+        shape_b_(shape_b),
+        numel_a_(numel_a),
+        numel_b_(numel_b),
+        ndims_(ndims) {}
+
+  HOSTDEVICE void operator()(int64_t idx) {
+    int64_t index = idx;
+    int64_t index_a = 0;
+    int64_t index_b = 0;
+    for (int i = 0; i < ndims_; i++) {
+      auto pos_i = index / stride_dout_[i];
+      index = index % stride_dout_[i];
+      auto pos_ai = pos_i / shape_b_[i];
+      auto pos_bi = pos_i % shape_b_[i];
+      index_a += stride_a_[i] * pos_ai;
+      index_b += stride_b_[i] * pos_bi;
+    }
+
+    if (dout_a_) {
+      size_t index_out_a = index_a * numel_b_ + index_b;
+      dout_a_[index_out_a] =
+          dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag);
+    }
+    if (dout_b_) {
+      size_t index_out_b = index_b * numel_a_ + index_a;
+      dout_b_[index_out_b] =
+          dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag);
+    }
+  }
+
+ private:
+  const complex128* dout_;
+  const complex128* A_;
+  const complex128* B_;
+  complex128* dout_a_;
+  complex128* dout_b_;
+  const int64_t* stride_dout_;
+  const int64_t* stride_a_;
+  const int64_t* stride_b_;
+  const int64_t* shape_b_;
+  const int64_t numel_a_;
+  const int64_t numel_b_;
+  const int ndims_;
+};
+
 template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
diff --git a/python/paddle/fluid/tests/unittests/test_kron_op.py b/python/paddle/fluid/tests/unittests/test_kron_op.py
index 68ad35489c..634739596e 100644
--- a/python/paddle/fluid/tests/unittests/test_kron_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kron_op.py
@@ -102,5 +102,90 @@ class TestKronLayer(unittest.TestCase):
             np.testing.assert_allclose(c, np.kron(a, b))
 
 
+class TestComplexKronOp(OpTest):
+    def setUp(self):
+        self.op_type = "kron"
+        self.x_shape = np.array([10, 10])
+        self.y_shape = np.array([3, 35])
+        self.out_shape = self.x_shape * self.y_shape
+        self.init_base_dtype()
+        self.init_input_output()
+        self.init_grad_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.float64
+
+    def init_input_output(self):
+        self.x = np.random.random(self.x_shape).astype(
+            self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(
+            self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype)
+        self.out = np.kron(self.x, self.y)
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones(
+            self.out_shape, self.dtype)
+        self.grad_x = self.get_grad_x_by_numpy()
+        self.grad_y = self.get_grad_y_by_numpy()
+
+    def get_grad_x_by_numpy(self):
+        grad_x = np.zeros(self.x_shape, np.complex)
+        for x_i in range(self.x_shape[0]):
+            for x_j in range(self.x_shape[1]):
+                for i in range(self.y_shape[0]):
+                    for j in range(self.y_shape[1]):
+                        idx_i = x_i * self.y_shape[0] + i
+                        idx_j = x_j * self.y_shape[1] + j
+                        grad_x[x_i][x_j] += self.grad_out[idx_i][
+                            idx_j] * np.conj(self.y[i][j])
+        return grad_x
+
+    def get_grad_y_by_numpy(self):
+        grad_y = np.zeros(self.y_shape, np.complex)
+        for y_i in range(self.y_shape[0]):
+            for y_j in range(self.y_shape[1]):
+                for x_i in range(self.x_shape[0]):
+                    for x_j in range(self.x_shape[1]):
+                        idx_i = x_i * self.y_shape[0] + y_i
+                        idx_j = x_j * self.y_shape[1] + y_j
+                        grad_y[y_i][y_j] += self.grad_out[idx_i][
+                            idx_j] * np.conj(self.x[x_i][x_j])
+        return grad_y
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            user_defined_grads=[self.grad_x, self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            user_defined_grads=[self.grad_x],
+            user_defined_grad_outputs=[self.grad_out])
+
+
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
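
For reference, the rule that the KronGradElemFunctor specializations and the test's get_grad_x_by_numpy/get_grad_y_by_numpy both encode is: for complex inputs, each gradient entry accumulates the upstream gradient over the block of the output it produced, multiplied by the conjugate of the other operand (dout * conj(B) for grad X, dout * conj(A) for grad Y). A minimal plain-NumPy sketch of that rule, independent of Paddle; the shapes and the seed below are illustrative assumptions, not values taken from the patch:

import numpy as np

# out = kron(x, y) has out[i*P + p, j*Q + q] = x[i, j] * y[p, q], so
# grad_x[i, j] sums grad_out over the y-shaped block produced by x[i, j],
# times conj(y); grad_y sums the blocks weighted by conj(x[i, j]).
rng = np.random.default_rng(0)  # illustrative seed
x = rng.random((2, 3)) + 1j * rng.random((2, 3))
y = rng.random((4, 5)) + 1j * rng.random((4, 5))
grad_out = np.ones((2 * 4, 3 * 5), dtype=np.complex128)

P, Q = y.shape
grad_x = np.empty_like(x)
grad_y = np.zeros_like(y)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        block = grad_out[i * P:(i + 1) * P, j * Q:(j + 1) * Q]
        grad_x[i, j] = np.sum(block * np.conj(y))
        grad_y += block * np.conj(x[i, j])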