* "fix double type error"

* "fix ci"

* "softmax fp64"

* "fix momentum"

* "fix ci"
trainerSaveLoadParams
dzhwinter 7 years ago committed by GitHub
parent 1ae086ed19
commit f63ff90b03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,6 +17,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@ -50,6 +52,12 @@ class MomentumOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {

@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@ -47,9 +46,9 @@ Scale operator
$$Out = scale*X$$
)DOC");
AddAttr<AttrType>("scale",
"(float, default 1.0)"
"The scaling factor of the scale operator.")
AddAttr<float>("scale",
"(float, default 1.0)"
"The scaling factor of the scale operator.")
.SetDefault(1.0);
}
};
@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
ops::ScaleGradMaker);
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker);
REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,

@ -164,7 +164,9 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(
softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>);
softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
softmax_grad,
ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);

@ -19,6 +19,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
softmax, ops::SoftmaxKernel<plat::CUDADeviceContext, float>,
ops::SoftmaxKernel<plat::CUDADeviceContext, double>,
ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(softmax_grad,
ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
softmax_grad, ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>);

@ -75,4 +75,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(top_k,
ops::TopkKernel<paddle::platform::CPUPlace, float>);
ops::TopkKernel<paddle::platform::CPUPlace, float>,
ops::TopkKernel<paddle::platform::CPUPlace, double>);

@ -318,4 +318,5 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>,
paddle::operators::TopkOpCUDAKernel<double>);

Loading…
Cancel
Save