|
|
|
@ -153,8 +153,8 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(CVMNoNeedBufferVarInference, "CVM");
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(CVMGradNoNeedBufferVarInference, "X");
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(CVMNoNeedBufferVarInferer, "CVM");
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(CVMGradNoNeedBufferVarInferer, "X");
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
@ -163,10 +163,10 @@ namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker,
|
|
|
|
|
ops::CVMGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::CVMGradOpMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::CVMNoNeedBufferVarInference);
|
|
|
|
|
ops::CVMNoNeedBufferVarInferer);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp,
|
|
|
|
|
ops::CVMGradNoNeedBufferVarInference);
|
|
|
|
|
ops::CVMGradNoNeedBufferVarInferer);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel<float>, ops::CVMOpKernel<double>);
|
|
|
|
|
|
|
|
|
|