@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Param"),
                    "Input(param) of Momentum should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Grad"),
@@ -45,12 +45,15 @@ class MomentumOp : public framework::OperatorWithKernel {
                    "Output(VelocityOut) of Momentum should not be null.");
 
     auto param_dim = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("Grad"),
-        "Param and Grad input of MomentumOp should have the same dimension.");
-    PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("Velocity"),
-        "Param and Velocity of MomentumOp should have the same dimension.");
+    if (ctx->GetInputsVarType("Grad")[0] ==
+        framework::proto::VarType::LOD_TENSOR) {
+      PADDLE_ENFORCE_EQ(
+          param_dim, ctx->GetInputDim("Grad"),
+          "Param and Grad input of MomentumOp should have the same dimension.");
+      PADDLE_ENFORCE_EQ(
+          param_dim, ctx->GetInputDim("Velocity"),
+          "Param and Velocity of MomentumOp should have the same dimension.");
+    }
     PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
                       "Learning_rate should be a scalar");
 
@@ -58,13 +61,34 @@ class MomentumOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("VelocityOut", param_dim);
   }
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
+class MomentumOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto input_var = op_desc.Input("Param")[0];
+    for (auto& out_var : op_desc.Output("ParamOut")) {
+      if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+          framework::proto::VarType::SELECTED_ROWS) {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::SELECTED_ROWS);
+      } else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+                 framework::proto::VarType::LOD_TENSOR) {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::LOD_TENSOR);
+      } else {
+        PADDLE_THROW(
+            "Only support LodTensor and SelectedRows, Unexpected Input Type.");
+      }
+    }
+  }
+};
+
 class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -115,6 +139,9 @@ $$
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
-REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
-                       ops::MomentumOpKernel<double>);
+REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker,
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::MomentumOpInferVarType);
+REGISTER_OP_CPU_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
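
Note on the last hunk (not part of the diff itself): the CPU kernel is re-registered as MomentumOpKernel<paddle::platform::CPUDeviceContext, float/double>, i.e. the kernel template now takes a device context as its first parameter. The sketch below shows what the corresponding declaration in momentum_op.h would look like under that assumption; the header is not shown in this hunk, so the Compute body here is only a placeholder.

// Illustrative sketch only (assumed from the registration change above);
// not the actual contents of momentum_op.h in this PR.
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Dense momentum update (no Nesterov), for reference:
    //   velocity_out = mu * velocity + grad
    //   param_out    = param - learning_rate * velocity_out
    // A branch on the Grad variable type (LoDTensor vs. SelectedRows)
    // would live here to support the sparse path enabled by this PR.
  }
};

}  // namespace operators
}  // namespace paddle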