@ -483,6 +483,11 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public :
void Make ( ) override {
AddInput ( " X " , " Input of Pow operator " ) ;
AddInput ( " FactorTensor " ,
" (Tensor<float>, optional). If provided, pow will use this "
" The shape of FactorTensor MUST BE [1]. "
" it has higher priority than attr(factor). " )
. AsDispensable ( ) ;
AddOutput ( " Out " , " Output of Pow operator " ) ;
AddAttr < float > ( " factor " , " The exponential factor of Pow " ) . SetDefault ( 1.0f ) ;
AddComment ( R " DOC(
@ -778,6 +783,75 @@ DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
{ framework : : GradVarName ( " Out " ) ,
framework : : GradVarName ( " X " ) } ) ;
class PowGradOpDescMaker : public framework : : SingleGradOpDescMaker {
public :
using framework : : SingleGradOpDescMaker : : SingleGradOpDescMaker ;
protected :
std : : unique_ptr < framework : : OpDesc > Apply ( ) const override {
std : : unique_ptr < framework : : OpDesc > op ( new framework : : OpDesc ( ) ) ;
op - > SetType ( " pow_grad " ) ;
op - > SetInput ( " X " , Input ( " X " ) ) ;
op - > SetInput ( framework : : GradVarName ( " Out " ) , OutputGrad ( " Out " ) ) ;
op - > SetOutput ( framework : : GradVarName ( " X " ) , InputGrad ( " X " ) ) ;
op - > SetInput ( " FactorTensor " , Input ( " FactorTensor " ) ) ;
op - > SetAttrMap ( Attrs ( ) ) ;
return op ;
}
} ;
class PowOp : public framework : : OperatorWithKernel {
public :
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
ctx - > ShareDim ( " X " , /*->*/ " Out " ) ;
ctx - > ShareLoD ( " X " , /*->*/ " Out " ) ;
}
protected :
framework : : OpKernelType GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const override {
return GetKernelType ( ctx , * this , " X " ) ;
}
framework : : OpKernelType GetKernelTypeForVar (
const std : : string & var_name , const Tensor & tensor ,
const framework : : OpKernelType & expected_kernel_type ) const override {
if ( var_name = = " FactorTensor " ) {
return expected_kernel_type ;
}
return framework : : OpKernelType ( expected_kernel_type . data_type_ ,
tensor . place ( ) , tensor . layout ( ) ) ;
}
} ;
class PowOpGrad : public framework : : OperatorWithKernel {
public :
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
auto out_grad_name = framework : : GradVarName ( " Out " ) ;
ctx - > ShareDim ( out_grad_name , framework : : GradVarName ( " X " ) ) ;
ctx - > ShareLoD ( out_grad_name , framework : : GradVarName ( " X " ) ) ;
}
protected :
framework : : OpKernelType GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const override {
return GetKernelType ( ctx , * this , framework : : GradVarName ( " Out " ) ) ;
}
framework : : OpKernelType GetKernelTypeForVar (
const std : : string & var_name , const Tensor & tensor ,
const framework : : OpKernelType & expected_kernel_type ) const override {
if ( var_name = = " FactorTensor " ) {
return expected_kernel_type ;
}
return framework : : OpKernelType ( expected_kernel_type . data_type_ ,
tensor . place ( ) , tensor . layout ( ) ) ;
}
} ;
} // namespace operators
} // namespace paddle
@ -907,3 +981,22 @@ REGISTER_OP_CPU_KERNEL(
ops : : SquareDoubleGradKernel < plat : : CPUDeviceContext ,
ops : : SquareGradGradFunctor < plat : : float16 > > ) ;
/* ========================================================================== */
/* ========================== pow register ============================ */
REGISTER_OPERATOR (
pow , ops : : PowOp , ops : : PowOpMaker , ops : : ActivationOpInferVarType ,
ops : : PowGradOpDescMaker ,
std : : conditional < ops : : CanInplaceAct < ops : : PowGradFunctor < float > > ( ) ,
: : paddle : : framework : : SingleOpInplaceInToOut , void > : : type ) ;
REGISTER_OPERATOR ( pow_grad , ops : : PowOpGrad ,
ops : : ActivationGradOpInplaceInference ) ;
REGISTER_OP_CPU_KERNEL (
pow , ops : : PowKernel < plat : : CPUDeviceContext , ops : : PowFunctor < float > > ,
ops : : PowKernel < plat : : CPUDeviceContext , ops : : PowFunctor < double > > ) ;
REGISTER_OP_CPU_KERNEL (
pow_grad ,
ops : : PowGradKernel < plat : : CPUDeviceContext , ops : : PowGradFunctor < float > > ,
ops : : PowGradKernel < plat : : CPUDeviceContext , ops : : PowGradFunctor < double > > ) ;
/* ========================================================================== */