|
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/mean_op.h"
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -42,6 +42,14 @@ Mean Operator calculates the mean of all elements in X.
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
|
|
|
|
|
protected:
|
|
|
|
|
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
|
|
|
|
|
const override {
|
|
|
|
|
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class MeanGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -50,6 +58,14 @@ class MeanGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
|
ctx->ShareLoD("X", framework::GradVarName("X"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto input_data_type =
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("X")->type());
|
|
|
|
|
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class MeanGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
@ -71,7 +87,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
|
|
|
|
|
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
|
|
|
|
|
ops::MeanGradMaker);
|
|
|
|
|
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|