|
|
|
@ -17,9 +17,9 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
class MeanOp : public OperatorWithKernel {
|
|
|
|
|
class MeanOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const InferShapeContext &ctx) const override {
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one");
|
|
|
|
|
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
|
|
|
|
|
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr,
|
|
|
|
@ -28,9 +28,9 @@ class MeanOp : public OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class MeanOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The input of mean op");
|
|
|
|
|
AddOutput("Out", "The output of mean op").IgnoreGradient();
|
|
|
|
@ -38,9 +38,9 @@ class MeanOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class MeanGradOp : public OperatorWithKernel {
|
|
|
|
|
class MeanGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const InferShapeContext &ctx) const override {
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
|
|
|
|
|
->Resize(ctx.Input<Tensor>("X")->dims());
|
|
|
|
|
}
|
|
|
|
@ -49,7 +49,10 @@ class MeanGradOp : public OperatorWithKernel {
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(mean,
|
|
|
|
|
ops::MeanKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(mean_grad,
|
|
|
|
|
ops::MeanGradKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|