Merge pull request #4565 from kavyasrinet/rmsprop
	
		
	
				
					
				
			Adding the implementation for rmsprop operator
			revert-4814-Add_sequence_project_op
						commit
						48f98a6770
					
				| @ -0,0 +1,120 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/operators/rmsprop_op.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| 
 | ||||
| class RmspropOp : public framework::OperatorWithKernel { | ||||
|  public: | ||||
|   using framework::OperatorWithKernel::OperatorWithKernel; | ||||
| 
 | ||||
|  protected: | ||||
|   void InferShape(framework::InferShapeContextBase *ctx) const override { | ||||
|     PADDLE_ENFORCE(ctx->HasInput("Param"), | ||||
|                    "Input(Param) of RmspropOp should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasInput("MeanSquare"), | ||||
|                    "Input(MeanSquare) of RmspropOp should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasInput("LearningRate"), | ||||
|                    "Input(LearningRate) of RmspropOp should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasInput("Grad"), | ||||
|                    "Input(Grad) of RmspropOp should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasInput("Moment"), | ||||
|                    "Input(Moment) of RmspropOp should not be null."); | ||||
| 
 | ||||
|     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), | ||||
|                    "Output(param_out) of RmspropOp should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasOutput("MomentOut"), | ||||
|                    "Output(Momentum_out) of RmspropOp should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"), | ||||
|                    "Output(MeanSquareOut) of RmspropOp should not be null."); | ||||
| 
 | ||||
|     auto param_dim = ctx->GetInputDim("Param"); | ||||
|     PADDLE_ENFORCE_EQ( | ||||
|         param_dim, ctx->GetInputDim("Grad"), | ||||
|         "Param and grad input of RmspropOp should have the same dimension."); | ||||
|     PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"), | ||||
|                       "Param and Momentum input of RmspropOp " | ||||
|                       "should have the same dimension."); | ||||
|     PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"), | ||||
|                       "Param and Momentum input of RmspropOp " | ||||
|                       "should have the same dimension."); | ||||
| 
 | ||||
|     auto lr_dim = ctx->GetInputDim("LearningRate"); | ||||
|     PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, | ||||
|                       "Learning Rate should be a scalar."); | ||||
| 
 | ||||
|     ctx->SetOutputDim("ParamOut", param_dim); | ||||
|     ctx->SetOutputDim("MomentOut", param_dim); | ||||
|     ctx->SetOutputDim("MeanSquareOut", param_dim); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { | ||||
|  public: | ||||
|   RmspropOpMaker(framework::OpProto *proto, | ||||
|                  framework::OpAttrChecker *op_checker) | ||||
|       : OpProtoAndCheckerMaker(proto, op_checker) { | ||||
|     AddInput("Param", | ||||
|              "(Tensor, default Tensor<float>) " | ||||
|              "Input parameter value that has to be updated"); | ||||
|     AddInput("MeanSquare", | ||||
|              "(Tensor, default Tensor<float>)" | ||||
|              " The mean square value that gets updated"); | ||||
|     AddInput("LearningRate", | ||||
|              "(Tensor, default Tensor<float>) " | ||||
|              "The learning rate should be a tensor of size 1"); | ||||
|     AddInput("Grad", | ||||
|              "(Tensor, default Tensor<float>) " | ||||
|              "Input gradient of the parameter"); | ||||
|     AddInput("Moment", | ||||
|              "(Tensor, default Tensor<float>) The moment that gets updated"); | ||||
| 
 | ||||
|     AddOutput("ParamOut", "(Tensor) Output updated parameter value"); | ||||
|     AddOutput("MomentOut", "(Tensor) Output updated moment"); | ||||
|     AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value"); | ||||
| 
 | ||||
|     AddAttr<float>("epsilon", | ||||
|                    "(float, default 1e-10) Constant " | ||||
|                    "for numerical stability.") | ||||
|         .SetDefault(1.0e-10f); | ||||
|     AddAttr<float>("decay", | ||||
|                    "(float, default 0.9) " | ||||
|                    "Discounting factor for coming gradient.") | ||||
|         .SetDefault(0.9f); | ||||
|     AddAttr<float>("momentum", "(float, default 0.0) Constant value") | ||||
|         .SetDefault(0.0f); | ||||
|     AddComment(R"DOC( | ||||
| 
 | ||||
| RMSprop | ||||
| 
 | ||||
| MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad | ||||
| MomentOut = momentum * Moment + | ||||
|             LearningRate * Grad / sqrt(MeanSquareOut + epsilon) | ||||
| ParamOut = Param -  MomentOut | ||||
| 
 | ||||
| The original slides that proposed RMSprop: Slide 29 of | ||||
| http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 | ||||
| 
 | ||||
| )DOC"); | ||||
|   } | ||||
| }; | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace ops = paddle::operators; | ||||
| REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker); | ||||
| REGISTER_OP_CPU_KERNEL(rmsprop, | ||||
|                        ops::RmspropOpKernel<paddle::platform::CPUPlace, float>); | ||||
| @ -0,0 +1,20 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||||
| 
 | ||||
|    Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|    you may not use this file except in compliance with the License. | ||||
|    You may obtain a copy of the License at | ||||
| 
 | ||||
|    http://www.apache.org/licenses/LICENSE-2.0 | ||||
| 
 | ||||
|    Unless required by applicable law or agreed to in writing, software | ||||
|    distributed under the License is distributed on an "AS IS" BASIS, | ||||
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|    See the License for the specific language governing permissions and | ||||
|    limitations under the License. */ | ||||
| 
 | ||||
| #define EIGEN_USE_GPU | ||||
| #include "paddle/operators/rmsprop_op.h" | ||||
| 
 | ||||
| namespace ops = paddle::operators; | ||||
| REGISTER_OP_GPU_KERNEL(rmsprop, | ||||
|                        ops::RmspropOpKernel<paddle::platform::GPUPlace, float>); | ||||
| @ -0,0 +1,67 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| #include "paddle/framework/eigen.h" | ||||
| #include "paddle/framework/op_registry.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| 
 | ||||
| using Tensor = framework::Tensor; | ||||
| template <typename T, int MajorType = Eigen::RowMajor, | ||||
|           typename IndexType = Eigen::DenseIndex> | ||||
| using EigenVector = framework::EigenVector<T, MajorType, IndexType>; | ||||
| 
 | ||||
| template <typename Place, typename T> | ||||
| class RmspropOpKernel : public framework::OpKernel<T> { | ||||
|  public: | ||||
|   void Compute(const framework::ExecutionContext& ctx) const override { | ||||
|     auto* param_out = ctx.Output<Tensor>("ParamOut"); | ||||
|     auto* moment_out = ctx.Output<Tensor>("MomentOut"); | ||||
|     auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut"); | ||||
| 
 | ||||
|     auto grad = ctx.Input<Tensor>("Grad"); | ||||
| 
 | ||||
|     param_out->mutable_data<T>(ctx.GetPlace()); | ||||
|     moment_out->mutable_data<T>(ctx.GetPlace()); | ||||
|     mean_square_out->mutable_data<T>(ctx.GetPlace()); | ||||
| 
 | ||||
|     float epsilon = ctx.Attr<float>("epsilon"); | ||||
|     float rho = ctx.Attr<float>("decay"); | ||||
|     float momentum = ctx.Attr<float>("momentum"); | ||||
| 
 | ||||
|     auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param")); | ||||
|     auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare")); | ||||
|     auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate")); | ||||
|     auto g = EigenVector<T>::Flatten(*grad); | ||||
|     auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment")); | ||||
| 
 | ||||
|     auto p_out = EigenVector<T>::Flatten(*param_out); | ||||
|     auto mom_out = EigenVector<T>::Flatten(*moment_out); | ||||
|     auto ms_out = EigenVector<T>::Flatten(*mean_square_out); | ||||
|     auto place = ctx.GetEigenDevice<Place>(); | ||||
| 
 | ||||
|     Eigen::DSizes<int, 1> grad_dsize(grad->numel()); | ||||
| 
 | ||||
|     ms_out.device(place) = rho * ms + (1 - rho) * g * g; | ||||
|     mom_out.device(place) = | ||||
|         momentum * mom + | ||||
|         lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt(); | ||||
|     p_out.device(place) = p - mom_out; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,89 @@ | ||||
| import unittest | ||||
| import numpy as np | ||||
| from op_test import OpTest | ||||
| 
 | ||||
| 
 | ||||
class TestRmspropOp1(OpTest):
    ''' Check the rmsprop operator against a NumPy reference when epsilon,
    decay and momentum are all passed explicitly as attributes.
    '''

    def setUp(self):
        self.op_type = "rmsprop"

        shape = (123, 321)
        p = np.random.random(shape).astype("float32")
        ms = np.random.random(shape).astype("float32")
        lr = np.array([0.01]).astype("float32")
        g = np.random.random(shape).astype("float32")
        mom = np.zeros(shape).astype("float32")

        eps, rho, mu = 1e-6, 0.9, 0.0

        self.inputs = dict(
            Param=p, MeanSquare=ms, LearningRate=lr, Grad=g, Moment=mom)
        self.attrs = dict(epsilon=eps, decay=rho, momentum=mu)

        # NumPy reference for a single RMSProp step.
        expected_ms = rho * ms + (1 - rho) * g * g
        scaled_grad = lr * g / np.sqrt(expected_ms + eps)
        expected_mom = mu * mom + scaled_grad
        expected_p = p - expected_mom

        self.outputs = dict(
            ParamOut=expected_p,
            MomentOut=expected_mom,
            MeanSquareOut=expected_ms)

    def test_check_output(self):
        self.check_output()
| 
 | ||||
| 
 | ||||
class TestRmspropOp2(OpTest):
    ''' Check the rmsprop operator against a NumPy reference when no
    attributes are passed; the reference uses epsilon=1e-10, decay=0.9 and
    momentum=0.0.
    '''

    def setUp(self):
        self.op_type = "rmsprop"

        shape = (123, 321)
        p = np.random.random(shape).astype("float32")
        ms = np.random.random(shape).astype("float32")
        lr = np.array([0.01]).astype("float32")
        g = np.random.random(shape).astype("float32")
        mom = np.zeros(shape).astype("float32")

        # Values used by the reference only; self.attrs is intentionally
        # left unset in this test.
        eps, rho, mu = 1.0e-10, 0.9, 0.0

        self.inputs = dict(
            Param=p, MeanSquare=ms, LearningRate=lr, Grad=g, Moment=mom)

        # NumPy reference for a single RMSProp step.
        expected_ms = rho * ms + (1 - rho) * g * g
        scaled_grad = lr * g / np.sqrt(expected_ms + eps)
        expected_mom = mu * mom + scaled_grad
        expected_p = p - expected_mom

        self.outputs = dict(
            ParamOut=expected_p,
            MomentOut=expected_mom,
            MeanSquareOut=expected_ms)

    def test_check_output(self):
        self.check_output()
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
					Loading…
					
					
				
		Reference in new issue