complex gradient matmul (#29966)

* dot op: support complex types

* matmul: support complex types

* add test cases

* matmul broadcast gradient: support complex types

* move ConjFunctor to complex_functors.h
chentianyu03 committed via GitHub
parent b0bd93de00
commit e012930aa3
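For reference, the gradient rule this patch implements for complex matmul matches the updated tests below: with upstream gradient G, dX = G * conj(Y)^T and dY = conj(X)^T * G. A minimal numpy sketch of that rule (illustration only; the shapes are arbitrary and not part of the patch):

import numpy as np

# Hypothetical shapes chosen only for illustration.
x = np.random.random((3, 4)) + 1j * np.random.random((3, 4))
y = np.random.random((4, 5)) + 1j * np.random.random((4, 5))
grad_out = np.ones((3, 5)) + 1j * np.ones((3, 5))  # upstream gradient G

grad_x = np.matmul(grad_out, np.conj(y).T)  # dX = G * conj(Y)^T -> (3, 4)
grad_y = np.matmul(np.conj(x).T, grad_out)  # dY = conj(X)^T * G -> (4, 5)

assert grad_x.shape == x.shape and grad_y.shape == y.shape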

@@ -17,49 +17,13 @@
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
using EnableComplex =
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value>::type;
template <typename T, typename Enable = void>
struct ConjFunctor;
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[idx] = T(input_[idx].real, -input_[idx].imag);
}
const T* input_;
int64_t numel_;
T* output_;
};
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; }
const T* input_;
int64_t numel_;
T* output_;
};
template <typename DeviceContext, typename T>
class ConjKernel : public framework::OpKernel<T> {
public:
@@ -74,7 +38,7 @@ class ConjKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
ConjFunctor<T> functor(x_data, numel, out_data);
math::ConjFunctor<T> functor(x_data, numel, out_data);
for_range(functor);
}
};

@@ -152,9 +152,17 @@ REGISTER_OP_CPU_KERNEL(
dot, ops::DotKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);

@@ -17,12 +17,17 @@
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(dot, ops::DotKernel<plat::CUDADeviceContext, float>,
ops::DotKernel<plat::CUDADeviceContext, double>,
ops::DotKernel<plat::CUDADeviceContext, int>,
ops::DotKernel<plat::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(dot_grad,
ops::DotGradKernel<plat::CUDADeviceContext, float>,
ops::DotGradKernel<plat::CUDADeviceContext, double>,
ops::DotGradKernel<plat::CUDADeviceContext, int>,
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
dot, ops::DotKernel<plat::CUDADeviceContext, float>,
ops::DotKernel<plat::CUDADeviceContext, double>,
ops::DotKernel<plat::CUDADeviceContext, int>,
ops::DotKernel<plat::CUDADeviceContext, int64_t>,
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
dot_grad, ops::DotGradKernel<plat::CUDADeviceContext, float>,
ops::DotGradKernel<plat::CUDADeviceContext, double>,
ops::DotGradKernel<plat::CUDADeviceContext, int>,
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex128>);

File diff suppressed because it is too large.

@@ -135,6 +135,43 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
int64_t numel_;
};
template <typename T>
using EnableComplex =
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value>::type;
template <typename T, typename Enable = void>
struct ConjFunctor;
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[idx] = T(input_[idx].real, -input_[idx].imag);
}
const T* input_;
int64_t numel_;
T* output_;
};
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; }
const T* input_;
int64_t numel_;
T* output_;
};
} // namespace math
} // namespace operators
} // namespace paddle
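For intuition only: the functor added above conjugates each element when T is complex64/complex128 and copies the input through for all other types. A rough numpy analogue of that dispatch (not part of the patch):

import numpy as np

def conj_like_functor(x):
    # Complex dtypes: negate the imaginary part; real dtypes: plain copy.
    return np.conj(x) if np.iscomplexobj(x) else x.copy()

assert conj_like_functor(np.array([1 + 2j]))[0] == 1 - 2j
assert conj_like_functor(np.array([3.0]))[0] == 3.0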

@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/dot_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#ifdef __NVCC__
@@ -468,6 +469,61 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
template <typename DeviceContext, typename T>
struct ConjHelper {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
dst.set_layout(src.layout());
dst.ShareDataWith(src);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex64> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex64>();
auto* dst_data = dst.mutable_data<paddle::platform::complex64>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex64)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex64> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex128> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex128>();
auto* dst_data = dst.mutable_data<paddle::platform::complex128>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex128)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex128> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
@@ -519,6 +575,8 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto x = *ctx.Input<framework::Tensor>("X");
auto y = *ctx.Input<framework::Tensor>("Y");
auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor y_conj(y.type());
framework::Tensor x_conj(x.type());
// get dims
std::vector<std::int64_t> x_dims = vectorize(x.dims());
@@ -537,7 +595,7 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
if (dout.numel() == 1) {
DotGradFunction<DeviceContext, T>(&x, &y, &dout, dx, dy, ctx);
DotGradFunction<DeviceContext, T>()(&x, &y, &dout, dx, dy, ctx);
return;
}
}
@@ -562,6 +620,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
// for complex types: dX needs conj(Y)
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(y, y_conj);
}
framework::DDim dy_dims;
@@ -570,19 +632,23 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
// for complex types: dY needs conj(X)
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
}
if (transpose_x && transpose_y) {
CalcInputGrad(ctx, y, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x, true, false, dy);
CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy);
} else if (transpose_x) {
CalcInputGrad(ctx, y, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x, false, false, dout, false, true, dy);
CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy);
} else if (transpose_y) {
CalcInputGrad(ctx, dout, false, false, y, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x, false, true, dy);
CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy);
} else {
CalcInputGrad(ctx, dout, false, false, y, true, false, dx);
CalcInputGrad(ctx, x, true, true, dout, false, true, dy);
CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx);
CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy);
}
if (dx) {
@@ -602,40 +668,44 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
Tensor dx_help, dy_help;
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
if (transpose_x) {
if (transpose_y) {
// X'Y': dA = Y'G', dB = G'X'
if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, true, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, true, ctx);
} else {
// X'Y: dX = YG', dY = XG
if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, false, false, ctx);
}
} else {
if (transpose_y) {
// XY': dX = GY, dY = G'X
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, false, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, false, ctx);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, true, false, ctx);
}
}
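In the broadcast branch above, the gradients are first computed as batched products against the conjugated inputs (dx_help, dy_help) and then reduced over the broadcast axes (hence the reduce_sum_op.h include earlier in this file). A numpy shape sketch mirroring TestComplexMatMulOpBroadcast below (illustrative only):

import numpy as np

x = np.random.random((10, 2, 5)) + 1j * np.random.random((10, 2, 5))
y = np.random.random((5, 20)) + 1j * np.random.random((5, 20))
grad_out = np.ones((10, 2, 20)) + 1j * np.ones((10, 2, 20))

grad_x = np.matmul(grad_out, np.conj(y).T)  # (10, 2, 5), same shape as x
# y was broadcast over the batch dimension, so its gradient is reduced over it.
grad_y = np.sum(np.matmul(np.conj(x).transpose(0, 2, 1), grad_out), axis=0)  # (5, 20)

assert grad_x.shape == x.shape and grad_y.shape == y.shape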

@@ -101,5 +101,127 @@ class TestDygraph(unittest.TestCase):
paddle.dot(x1, y1).numpy(), np.array([[17], [58]])))
class TestComplexDotOp(OpTest):
def setUp(self):
self.op_type = "dot"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(100).astype(
self.dtype) + 1J * np.random.random(100).astype(self.dtype)
self.y = np.random.random(100).astype(
self.dtype) + 1J * np.random.random(100).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones(1, self.dtype) + 1J * np.ones(1, self.dtype)
self.grad_x = self.grad_out * np.conj(self.y)
self.grad_y = self.grad_out * np.conj(self.x)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
class TestComplexDotOp2D(OpTest):
def setUp(self):
self.op_type = "dot"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(2, 100)).astype(self.dtype) + 1J * np.random.random(
(2, 100)).astype(self.dtype)
self.y = np.random.random(
(2, 100)).astype(self.dtype) + 1J * np.random.random(
(2, 100)).astype(self.dtype)
self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1)
def init_grad_input_output(self):
self.grad_out = np.ones((2, 1), self.dtype) + 1J * np.ones(
(2, 1), self.dtype)
self.grad_x = self._get_grad(self.grad_out, self.y)
self.grad_y = self._get_grad(self.grad_out, self.x)
def _get_grad(self, grad_out, input):
grad = np.empty((0, input.shape[1]))
for i in range(grad_out.shape[0]):
grad = np.append(grad, [grad_out[i] * np.conj(input[i])], axis=0)
return grad
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
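As a side note, the _get_grad helper in TestComplexDotOp2D above builds the per-row gradients with an explicit loop; the same values can be obtained with plain broadcasting, which may help when reading the test (illustrative only, not part of the patch):

import numpy as np

grad_out = np.ones((2, 1)) + 1j * np.ones((2, 1))
y = np.random.random((2, 100)) + 1j * np.random.random((2, 100))

# Loop form, as in TestComplexDotOp2D._get_grad
grad_loop = np.empty((0, y.shape[1]))
for i in range(grad_out.shape[0]):
    grad_loop = np.append(grad_loop, [grad_out[i] * np.conj(y[i])], axis=0)

# Equivalent broadcast form
grad_vec = grad_out * np.conj(y)

assert np.allclose(grad_loop, grad_vec)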

@@ -405,5 +405,126 @@ class TestMatMulV2API(unittest.TestCase):
result = paddle.matmul(x, y)
class TestComplexMatMulOp(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(10, 10)).astype(self.dtype) + 1J * np.random.random(
(10, 10)).astype(self.dtype)
self.y = np.random.random(
(10, 10)).astype(self.dtype) + 1J * np.random.random(
(10, 10)).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones(
(10, 10), self.dtype)
self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
class TestComplexMatMulOpBroadcast(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(10, 2, 5)).astype(self.dtype) + 1J * np.random.random(
(10, 2, 5)).astype(self.dtype)
self.y = np.random.random(
(5, 20)).astype(self.dtype) + 1J * np.random.random(
(5, 20)).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones((10, 2, 20), self.dtype) + 1J * np.ones(
(10, 2, 20), self.dtype)
self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
self.grad_y = np.sum(np.matmul(
np.conj(self.x).transpose(0, 2, 1), self.grad_out),
axis=0)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
if __name__ == "__main__":
paddle.enable_static()
unittest.main()

@@ -59,6 +59,7 @@ NEED_TO_FIX_OP_LIST = [
'lstmp',
'margin_rank_loss',
'matmul',
'matmul_v2',
'mul',
'multiplex',
'rank_loss',
