|
|
|
|
@ -17,6 +17,7 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
class ProximalAdagradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->SetOutputDim("ParamOut", param_dim);
|
|
|
|
|
ctx->SetOutputDim("MomentOut", param_dim);
|
|
|
|
|
}
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto input_data_type =
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
|