@ -22,6 +22,7 @@ limitations under the License. */
# include "paddle/fluid/framework/op_registry.h"
# include "paddle/fluid/operators/dot_op.h"
# include "paddle/fluid/operators/math/blas.h"
# include "paddle/fluid/operators/math/complex_functors.h"
# include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
# ifdef __NVCC__
@ -468,6 +469,61 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
ReshapeTensorIntoMatrixSequence ( y , mat_dim_y ) ;
}
template < typename DeviceContext , typename T >
struct ConjHelper {
explicit ConjHelper ( const framework : : ExecutionContext & ctx ) : ctx_ ( ctx ) { }
HOSTDEVICE void operator ( ) ( framework : : Tensor & src , framework : : Tensor & dst ) {
dst . Resize ( src . dims ( ) ) ;
dst . set_layout ( src . layout ( ) ) ;
dst . ShareDataWith ( src ) ;
return ;
}
const framework : : ExecutionContext & ctx_ ;
} ;
template < typename DeviceContext >
struct ConjHelper < DeviceContext , paddle : : platform : : complex64 > {
explicit ConjHelper ( const framework : : ExecutionContext & ctx ) : ctx_ ( ctx ) { }
HOSTDEVICE void operator ( ) ( framework : : Tensor & src , framework : : Tensor & dst ) {
dst . Resize ( src . dims ( ) ) ;
auto * src_data = src . data < paddle : : platform : : complex64 > ( ) ;
auto * dst_data = dst . mutable_data < paddle : : platform : : complex64 > (
ctx_ . GetPlace ( ) ,
size_t ( src . numel ( ) * sizeof ( paddle : : platform : : complex64 ) ) ) ;
platform : : ForRange < DeviceContext > for_range (
ctx_ . template device_context < DeviceContext > ( ) , src . numel ( ) ) ;
math : : ConjFunctor < paddle : : platform : : complex64 > functor (
src_data , src . numel ( ) , dst_data ) ;
for_range ( functor ) ;
return ;
}
const framework : : ExecutionContext & ctx_ ;
} ;
template < typename DeviceContext >
struct ConjHelper < DeviceContext , paddle : : platform : : complex128 > {
explicit ConjHelper ( const framework : : ExecutionContext & ctx ) : ctx_ ( ctx ) { }
HOSTDEVICE void operator ( ) ( framework : : Tensor & src , framework : : Tensor & dst ) {
dst . Resize ( src . dims ( ) ) ;
auto * src_data = src . data < paddle : : platform : : complex128 > ( ) ;
auto * dst_data = dst . mutable_data < paddle : : platform : : complex128 > (
ctx_ . GetPlace ( ) ,
size_t ( src . numel ( ) * sizeof ( paddle : : platform : : complex128 ) ) ) ;
platform : : ForRange < DeviceContext > for_range (
ctx_ . template device_context < DeviceContext > ( ) , src . numel ( ) ) ;
math : : ConjFunctor < paddle : : platform : : complex128 > functor (
src_data , src . numel ( ) , dst_data ) ;
for_range ( functor ) ;
return ;
}
const framework : : ExecutionContext & ctx_ ;
} ;
template < typename DeviceContext , typename T >
class MatMulV2GradKernel : public framework : : OpKernel < T > {
public :
@ -519,6 +575,8 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto x = * ctx . Input < framework : : Tensor > ( " X " ) ;
auto y = * ctx . Input < framework : : Tensor > ( " Y " ) ;
auto dout = * ctx . Input < framework : : Tensor > ( framework : : GradVarName ( " Out " ) ) ;
// Holders for conj(Y) / conj(X) (aliases of Y / X for real element types),
// filled by ConjHelper before the grad GEMMs.
framework::Tensor y_conj(y.type());
// BUG FIX: x_conj must carry X's dtype, not Y's. The original constructed it
// from y.type(); harmless only because X and Y share element type T here.
framework::Tensor x_conj(x.type());
// get dims
std : : vector < std : : int64_t > x_dims = vectorize ( x . dims ( ) ) ;
@ -537,7 +595,7 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if ( dx ) dx - > mutable_data < T > ( ctx . GetPlace ( ) ) ;
if ( dy ) dy - > mutable_data < T > ( ctx . GetPlace ( ) ) ;
if ( dout . numel ( ) = = 1 ) {
DotGradFunction < DeviceContext , T > ( & x , & y , & dout , dx , dy , ctx ) ;
DotGradFunction < DeviceContext , T > ( ) ( & x , & y , & dout , dx , dy , ctx ) ;
return ;
}
}
@ -562,6 +620,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if ( dx_dims ! = x . dims ( ) ) {
dx - > Resize ( x . dims ( ) ) ;
}
// for complex
ConjHelper < DeviceContext , T > conj_helper ( ctx ) ;
conj_helper ( y , y_conj ) ;
}
framework : : DDim dy_dims ;
@ -570,19 +632,23 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if ( dy_dims ! = y . dims ( ) ) {
dy - > Resize ( y . dims ( ) ) ;
}
// for complex
ConjHelper < DeviceContext , T > conj_helper ( ctx ) ;
conj_helper ( x , x_conj ) ;
}
if ( transpose_x & & transpose_y ) {
CalcInputGrad ( ctx , y , true , true , dout , true , false , dx ) ;
CalcInputGrad ( ctx , dout , true , true , x , true , false , dy ) ;
CalcInputGrad ( ctx , y _conj , true , true , dout , true , false , dx ) ;
CalcInputGrad ( ctx , dout , true , true , x _conj , true , false , dy ) ;
} else if ( transpose_x ) {
CalcInputGrad ( ctx , y , false , false , dout , true , false , dx ) ;
CalcInputGrad ( ctx , x , false , false , dout , false , true , dy ) ;
CalcInputGrad ( ctx , y _conj , false , false , dout , true , false , dx ) ;
CalcInputGrad ( ctx , x _conj , false , false , dout , false , true , dy ) ;
} else if ( transpose_y ) {
CalcInputGrad ( ctx , dout , false , false , y , false , true , dx ) ;
CalcInputGrad ( ctx , dout , true , true , x , false , true , dy ) ;
CalcInputGrad ( ctx , dout , false , false , y _conj , false , true , dx ) ;
CalcInputGrad ( ctx , dout , true , true , x _conj , false , true , dy ) ;
} else {
CalcInputGrad ( ctx , dout , false , false , y , true , false , dx ) ;
CalcInputGrad ( ctx , x , true , true , dout , false , true , dy ) ;
CalcInputGrad ( ctx , dout , false , false , y _conj , true , false , dx ) ;
CalcInputGrad ( ctx , x _conj , true , true , dout , false , true , dy ) ;
}
if ( dx ) {
@ -602,40 +668,44 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
VLOG ( 3 ) < < " It need cost much time to reduce sum for the broadcast and "
" wastes the memory. So we should avoid the case in reality " ;
Tensor dx_help , dy_help ;
ConjHelper < DeviceContext , T > conj_helper ( ctx ) ;
conj_helper ( x , x_conj ) ;
conj_helper ( y , y_conj ) ;
if ( transpose_x ) {
if ( transpose_y ) {
// X'Y': dA = Y'G', dB = G'X'
if ( dx )
MatMulFunction < DeviceContext , T > ( & y , & dout , y_dims , dout_dims ,
MatMulFunction < DeviceContext , T > ( & y _conj , & dout , y_dims , dout_dims ,
& dx_help , true , true , ctx ) ;
if ( dy )
MatMulFunction < DeviceContext , T > ( & dout , & x , dout_dims , x_dims ,
MatMulFunction < DeviceContext , T > ( & dout , & x _conj , dout_dims , x_dims ,
& dy_help , true , true , ctx ) ;
} else {
// X'Y: dX = YG', dY = XG
if ( dx )
MatMulFunction < DeviceContext , T > ( & y , & dout , y_dims , dout_dims ,
MatMulFunction < DeviceContext , T > ( & y _conj , & dout , y_dims , dout_dims ,
& dx_help , false , true , ctx ) ;
if ( dy )
MatMulFunction < DeviceContext , T > ( & x , & dout , x_dims , dout_dims ,
MatMulFunction < DeviceContext , T > ( & x _conj , & dout , x_dims , dout_dims ,
& dy_help , false , false , ctx ) ;
}
} else {
if ( transpose_y ) {
// XY': dX = GY, dY = G'X
if ( dx )
MatMulFunction < DeviceContext , T > ( & dout , & y , dout_dims , y_dims ,
MatMulFunction < DeviceContext , T > ( & dout , & y _conj , dout_dims , y_dims ,
& dx_help , false , false , ctx ) ;
if ( dy )
MatMulFunction < DeviceContext , T > ( & dout , & x , dout_dims , x_dims ,
MatMulFunction < DeviceContext , T > ( & dout , & x _conj , dout_dims , x_dims ,
& dy_help , true , false , ctx ) ;
} else {
// XY: dX = GY', dY = X'G
if ( dx )
MatMulFunction < DeviceContext , T > ( & dout , & y , dout_dims , y_dims ,
MatMulFunction < DeviceContext , T > ( & dout , & y _conj , dout_dims , y_dims ,
& dx_help , false , true , ctx ) ;
if ( dy )
MatMulFunction < DeviceContext , T > ( & x , & dout , x_dims , dout_dims ,
MatMulFunction < DeviceContext , T > ( & x _conj , & dout , x_dims , dout_dims ,
& dy_help , true , false , ctx ) ;
}
}