Fix/adam float64 (#10407)

* "optimizer op support float64"

* "fix ci"

* "fix ftrl op"
scopeFix
dzhwinter 8 years ago committed by GitHub
parent 6418c42148
commit a28dffbb0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdadeltaOp : public framework::OperatorWithKernel { class AdadeltaOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker { class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {

@ -23,6 +23,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdagradOp : public framework::OperatorWithKernel { class AdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker { class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {

@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdamOp : public framework::OperatorWithKernel { class AdamOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Moment1Out", param_dims); ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims); ctx->SetOutputDim("Moment2Out", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdamOpMaker : public framework::OpProtoAndCheckerMaker { class AdamOpMaker : public framework::OpProtoAndCheckerMaker {

@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdamaxOp : public framework::OperatorWithKernel { class AdamaxOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
ctx->SetOutputDim("InfNormOut", param_dims); ctx->SetOutputDim("InfNormOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker { class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {

@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class DecayedAdagradOp : public framework::OperatorWithKernel { class DecayedAdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {

@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class FTRLOp : public framework::OperatorWithKernel { class FTRLOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SquaredAccumOut", param_dim); ctx->SetOutputDim("SquaredAccumOut", param_dim);
ctx->SetOutputDim("LinearAccumOut", param_dim); ctx->SetOutputDim("LinearAccumOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {

@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class ProximalAdagradOp : public framework::OperatorWithKernel { class ProximalAdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {

@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class ProximalGDOp : public framework::OperatorWithKernel { class ProximalGDOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker { class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {

Loading…
Cancel
Save