|
|
@ -33,13 +33,23 @@ public:
|
|
|
|
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput("X", "The input of mean op");
|
|
|
|
AddInput("X", "The input of mean op");
|
|
|
|
AddOutput("Out", "The output of mean op");
|
|
|
|
AddOutput("Out", "The output of mean op").IgnoreGradient();
|
|
|
|
AddComment("Mean Operator");
|
|
|
|
AddComment("Mean Operator");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MeanGradOp : public OperatorWithKernel {
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|
|
|
void InferShape(const InferShapeContext &ctx) const override {
|
|
|
|
|
|
|
|
ctx.Output<Tensor>("X" + GRAD_VAR_SUFFIX())
|
|
|
|
|
|
|
|
->Resize(ctx.Input<Tensor>("X")->dims());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
|
|
|
|
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
|
|
|
|
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
|
|
|
|
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
|
|
|
|
|
|
|
|
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>);
|
|
|
|