|
|
@ -61,6 +61,8 @@ class ReduceMeanDoubleGradMaker : public framework::GradOpDescMakerBase {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ReduceMeanGradNoNeedBufferVarInference,
|
|
|
|
|
|
|
|
"X");
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
@ -73,7 +75,8 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker {
|
|
|
|
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
|
|
|
|
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
|
|
|
|
ops::ReduceMeanOpGradDescMaker);
|
|
|
|
ops::ReduceMeanOpGradDescMaker);
|
|
|
|
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
|
|
|
|
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
|
|
|
|
ops::ReduceMeanDoubleGradMaker);
|
|
|
|
ops::ReduceMeanDoubleGradMaker,
|
|
|
|
|
|
|
|
ops::ReduceMeanGradNoNeedBufferVarInference);
|
|
|
|
REGISTER_OP_CPU_KERNEL(reduce_mean,
|
|
|
|
REGISTER_OP_CPU_KERNEL(reduce_mean,
|
|
|
|
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
float, ops::MeanFunctor>,
|
|
|
|
float, ops::MeanFunctor>,
|
|
|
@ -83,12 +86,13 @@ REGISTER_OP_CPU_KERNEL(reduce_mean,
|
|
|
|
int, ops::MeanFunctor>,
|
|
|
|
int, ops::MeanFunctor>,
|
|
|
|
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
int64_t, ops::MeanFunctor>);
|
|
|
|
int64_t, ops::MeanFunctor>);
|
|
|
|
REGISTER_OP_CPU_KERNEL(reduce_mean_grad,
|
|
|
|
|
|
|
|
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
template <typename T>
|
|
|
|
float, ops::MeanGradFunctor>,
|
|
|
|
using CPUReduceMeanGradKernel =
|
|
|
|
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, T,
|
|
|
|
double, ops::MeanGradFunctor>,
|
|
|
|
ops::MeanGradFunctor, true>;
|
|
|
|
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
|
|
|
|
int, ops::MeanGradFunctor>,
|
|
|
|
REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>,
|
|
|
|
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
CPUReduceMeanGradKernel<double>,
|
|
|
|
int64_t, ops::MeanGradFunctor>);
|
|
|
|
CPUReduceMeanGradKernel<int>,
|
|
|
|
|
|
|
|
CPUReduceMeanGradKernel<int64_t>);
|
|
|
|