From 61c03f9d59969d698faca1547115d80e8faa3e9d Mon Sep 17 00:00:00 2001 From: Kavya Srinet Date: Mon, 2 Oct 2017 17:52:56 -0700 Subject: [PATCH 1/5] Adding the implementation for rmsprop operator --- paddle/operators/rmsprop_op.cc | 87 +++++++++++++++++++ paddle/operators/rmsprop_op.cu | 20 +++++ paddle/operators/rmsprop_op.h | 54 ++++++++++++ .../v2/framework/tests/test_rmsprop_op.py | 37 ++++++++ 4 files changed, 198 insertions(+) create mode 100644 paddle/operators/rmsprop_op.cc create mode 100644 paddle/operators/rmsprop_op.cu create mode 100644 paddle/operators/rmsprop_op.h create mode 100644 python/paddle/v2/framework/tests/test_rmsprop_op.py diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc new file mode 100644 index 0000000000..dcf3599f4d --- /dev/null +++ b/paddle/operators/rmsprop_op.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/rmsprop_op.h" + +namespace paddle { +namespace operators { + +class RmspropOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(param) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(grad) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment"), + "Input(moment) of RmspropOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(param_out) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MomentOut"), + "Output(moment_out) of RmspropOp should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and grad input of RmspropOp should have the same dimension."); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Moment"), + "Param and moment input of RmspropOp should have the same dimension."); + + ctx->SetOutputDim("ParamOut", param_dim); + ctx->SetOutputDim("MomentOut", param_dim); + } +}; + +class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { + public: + RmspropOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", "Input parameter"); + AddInput("Grad", "Input gradient"); + AddInput("Moment", "Second moment"); + + AddOutput("ParamOut", "Output parameter"); + AddOutput("MomentOut", "Output second moment"); + + AddAttr("learningRate", "Learning rate"); + AddAttr("epsilon", "Constant for numerical stability"); + AddAttr("decayRate", "Decay rate for moving average of gradients"); + AddComment(R"DOC( + +RMSprop + +MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad +ParamOut = Param - learningRate * Grad / (sqrt(MomentOut) + epsilon) + +The original 
slide(Slide 29 of +http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) +does not have the epsilon attribute. It is added here for numerical stability +to avoid division by zero. + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker); +REGISTER_OP_CPU_KERNEL(rmsprop, + ops::RmspropOpKernel); diff --git a/paddle/operators/rmsprop_op.cu b/paddle/operators/rmsprop_op.cu new file mode 100644 index 0000000000..52634a5481 --- /dev/null +++ b/paddle/operators/rmsprop_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/rmsprop_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(rmsprop, + ops::RmspropOpKernel); diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h new file mode 100644 index 0000000000..c94c24bddd --- /dev/null +++ b/paddle/operators/rmsprop_op.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class RmspropOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out = ctx.Output("ParamOut"); + auto moment_out = ctx.Output("MomentOut"); + + param_out->mutable_data(ctx.GetPlace()); + moment_out->mutable_data(ctx.GetPlace()); + + float lr = ctx.Attr("learningRate"); + float epsilon = ctx.Attr("epsilon"); + float decay = ctx.Attr("decayRate"); + + auto p = EigenVector::Flatten(*ctx.Input("Param")); + auto g = EigenVector::Flatten(*ctx.Input("Grad")); + auto m = EigenVector::Flatten(*ctx.Input("Moment")); + auto p_out = EigenVector::Flatten(*param_out); + auto m_out = EigenVector::Flatten(*moment_out); + auto place = ctx.GetEigenDevice(); + + m_out.device(place) = decay * m + (1 - decay) * g * g; + p_out.device(place) = p - lr * g / (m_out.sqrt() + epsilon); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py new file mode 100644 index 0000000000..1fc59a0f11 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -0,0 +1,37 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestRmspropOp(OpTest): + def setUp(self): + self.op_type = "rmsprop" + + param = np.random.random((123, 
321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + + learning_rate = 0.01 + epsilon = 1e-6 + decay_rate = 0.9 + + self.inputs = {'Param': param, 'Grad': grad, 'Moment': moment} + + self.attrs = { + 'learningRate': learning_rate, + 'epsilon': epsilon, + 'decayRate': decay_rate + } + + moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad + param_out = param - learning_rate * grad / (np.sqrt(moment_out) + + epsilon) + + self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() From 163d28714349d51596be1cb165f93be2b8290bda Mon Sep 17 00:00:00 2001 From: Kavya Srinet Date: Mon, 2 Oct 2017 19:23:05 -0700 Subject: [PATCH 2/5] Made learning rate the input --- paddle/operators/rmsprop_op.cc | 16 +++++++++++----- paddle/operators/rmsprop_op.h | 2 +- .../paddle/v2/framework/tests/test_rmsprop_op.py | 15 ++++++++------- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc index dcf3599f4d..602efab3db 100644 --- a/paddle/operators/rmsprop_op.cc +++ b/paddle/operators/rmsprop_op.cc @@ -24,11 +24,13 @@ class RmspropOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), - "Input(param) of RmspropOp should not be null."); + "Input(Param) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), - "Input(grad) of RmspropOp should not be null."); + "Input(Grad) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Moment"), - "Input(moment) of RmspropOp should not be null."); + "Input(Moment) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of RmspropOp should not be null."); 
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(param_out) of RmspropOp should not be null."); @@ -43,6 +45,10 @@ class RmspropOp : public framework::OperatorWithKernel { param_dim, ctx->GetInputDim("Moment"), "Param and moment input of RmspropOp should have the same dimension."); + auto lr_dim = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, + "Learning Rate should be a scalar."); + ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim); } @@ -56,11 +62,11 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Param", "Input parameter"); AddInput("Grad", "Input gradient"); AddInput("Moment", "Second moment"); + AddInput("LearningRate", "Learning Rate"); AddOutput("ParamOut", "Output parameter"); AddOutput("MomentOut", "Output second moment"); - AddAttr("learningRate", "Learning rate"); AddAttr("epsilon", "Constant for numerical stability"); AddAttr("decayRate", "Decay rate for moving average of gradients"); AddComment(R"DOC( @@ -68,7 +74,7 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { RMSprop MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad -ParamOut = Param - learningRate * Grad / (sqrt(MomentOut) + epsilon) +ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon) The original slide(Slide 29 of http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index c94c24bddd..65b9edd35b 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -34,13 +34,13 @@ class RmspropOpKernel : public framework::OpKernel { param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); - float lr = ctx.Attr("learningRate"); float epsilon = ctx.Attr("epsilon"); float decay = ctx.Attr("decayRate"); auto p = EigenVector::Flatten(*ctx.Input("Param")); auto g = EigenVector::Flatten(*ctx.Input("Grad")); auto m = 
EigenVector::Flatten(*ctx.Input("Moment")); + float lr = ctx.Input("LearningRate")->data()[0]; auto p_out = EigenVector::Flatten(*param_out); auto m_out = EigenVector::Flatten(*moment_out); auto place = ctx.GetEigenDevice(); diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py index 1fc59a0f11..64ca5da48e 100644 --- a/python/paddle/v2/framework/tests/test_rmsprop_op.py +++ b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -10,19 +10,20 @@ class TestRmspropOp(OpTest): param = np.random.random((123, 321)).astype("float32") grad = np.random.random((123, 321)).astype("float32") moment = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.01]).astype("float32") - learning_rate = 0.01 epsilon = 1e-6 decay_rate = 0.9 - self.inputs = {'Param': param, 'Grad': grad, 'Moment': moment} - - self.attrs = { - 'learningRate': learning_rate, - 'epsilon': epsilon, - 'decayRate': decay_rate + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment': moment, + 'LearningRate': learning_rate } + self.attrs = {'epsilon': epsilon, 'decayRate': decay_rate} + moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad param_out = param - learning_rate * grad / (np.sqrt(moment_out) + epsilon) From 94855f4af08002253fca10aab4bffc187e5c982f Mon Sep 17 00:00:00 2001 From: Kavya Srinet Date: Wed, 4 Oct 2017 12:45:56 -0700 Subject: [PATCH 3/5] Fixed changes proposed in the review --- paddle/operators/rmsprop_op.cc | 69 +++++++++++++------ paddle/operators/rmsprop_op.h | 19 +++-- .../v2/framework/tests/test_rmsprop_op.py | 24 ++++--- 3 files changed, 77 insertions(+), 35 deletions(-) diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc index 602efab3db..1e06e08ede 100644 --- a/paddle/operators/rmsprop_op.cc +++ b/paddle/operators/rmsprop_op.cc @@ -25,25 +25,32 @@ class RmspropOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContextBase *ctx) 
const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("MeanSquare"), + "Input(MeanSquare) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), "Input(Grad) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Moment"), "Input(Moment) of RmspropOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LearningRate"), - "Input(LearningRate) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(param_out) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("MomentOut"), - "Output(moment_out) of RmspropOp should not be null."); + "Output(Momentum_out) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"), + "Output(MeanSquareOut) of RmspropOp should not be null."); auto param_dim = ctx->GetInputDim("Param"); PADDLE_ENFORCE_EQ( param_dim, ctx->GetInputDim("Grad"), "Param and grad input of RmspropOp should have the same dimension."); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Moment"), - "Param and moment input of RmspropOp should have the same dimension."); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"), + "Param and Momentum input of RmspropOp " + "should have the same dimension."); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"), + "Param and Momentum input of RmspropOp " + "should have the same dimension."); auto lr_dim = ctx->GetInputDim("LearningRate"); PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, @@ -51,6 +58,7 @@ class RmspropOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim); + ctx->SetOutputDim("MeanSquareOut", param_dim); } }; @@ -59,27 +67,46 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { RmspropOpMaker(framework::OpProto 
*proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Param", "Input parameter"); - AddInput("Grad", "Input gradient"); - AddInput("Moment", "Second moment"); - AddInput("LearningRate", "Learning Rate"); - - AddOutput("ParamOut", "Output parameter"); - AddOutput("MomentOut", "Output second moment"); - - AddAttr("epsilon", "Constant for numerical stability"); - AddAttr("decayRate", "Decay rate for moving average of gradients"); + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter value that has to be updated"); + AddInput("MeanSquare", + "(Tensor, default Tensor)" + " The mean square value that gets updated"); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "The learning rate should be a tensor of size 1"); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter"); + AddInput("Moment", + "(Tensor, default Tensor) The moment that gets updated"); + + AddOutput("ParamOut", "(Tensor) Output updated parameter value"); + AddOutput("MomentOut", "(Tensor) Output updated moment"); + AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value"); + + AddAttr("epsilon", + "(float, default 1e-10) Constant " + "for numerical stability.") + .SetDefault(1e-10); + AddAttr("decay", + "(float, default 0.9) " + "Discounting factor for coming gradient.") + .SetDefault(0.9); + AddAttr("momentum", "(float, default 0.0) Constant value") + .SetDefault(0.0); AddComment(R"DOC( RMSprop -MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad -ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon) +MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad +MomentOut = momentum * Moment + + LearningRate * Grad / sqrt(MeanSquareOut + epsilon) +ParamOut = Param - MomentOut -The original slide(Slide 29 of +The original slides that proposed RMSprop: Slide 29 of http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -does not have the epsilon 
attribute. It is added here for numerical stability -to avoid division by zero. )DOC"); } diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index 65b9edd35b..ed4b283ce4 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -30,23 +30,30 @@ class RmspropOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto param_out = ctx.Output("ParamOut"); auto moment_out = ctx.Output("MomentOut"); + auto mean_square_out = ctx.Output("MeanSquareOut"); param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); + mean_square_out->mutable_data(ctx.GetPlace()); float epsilon = ctx.Attr("epsilon"); - float decay = ctx.Attr("decayRate"); + float rho = ctx.Attr("decay"); + float momentum = ctx.Attr("momentum"); auto p = EigenVector::Flatten(*ctx.Input("Param")); - auto g = EigenVector::Flatten(*ctx.Input("Grad")); - auto m = EigenVector::Flatten(*ctx.Input("Moment")); + auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); float lr = ctx.Input("LearningRate")->data()[0]; + auto g = EigenVector::Flatten(*ctx.Input("Grad")); + auto mom = EigenVector::Flatten(*ctx.Input("Moment")); + auto p_out = EigenVector::Flatten(*param_out); - auto m_out = EigenVector::Flatten(*moment_out); + auto mom_out = EigenVector::Flatten(*moment_out); + auto ms_out = EigenVector::Flatten(*mean_square_out); auto place = ctx.GetEigenDevice(); - m_out.device(place) = decay * m + (1 - decay) * g * g; - p_out.device(place) = p - lr * g / (m_out.sqrt() + epsilon); + ms_out.device(place) = rho * ms + (1 - rho) * g * g; + mom_out.device(place) = momentum * mom + lr * g / (ms_out + epsilon).sqrt(); + p_out.device(place) = p - mom_out; } }; diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py index 64ca5da48e..84bd815c8c 100644 --- a/python/paddle/v2/framework/tests/test_rmsprop_op.py +++ 
b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -8,27 +8,35 @@ class TestRmspropOp(OpTest): self.op_type = "rmsprop" param = np.random.random((123, 321)).astype("float32") + mean_square = np.random.random((123, 321)).astype("float32") + learning_rate = np.array([0.01]).astype("float32") grad = np.random.random((123, 321)).astype("float32") moment = np.zeros((123, 321)).astype("float32") - learning_rate = np.array([0.01]).astype("float32") epsilon = 1e-6 - decay_rate = 0.9 + decay = 0.9 + momentum = 0.0 self.inputs = { 'Param': param, + 'MeanSquare': mean_square, + 'LearningRate': learning_rate, 'Grad': grad, 'Moment': moment, - 'LearningRate': learning_rate } - self.attrs = {'epsilon': epsilon, 'decayRate': decay_rate} + self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum} - moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad - param_out = param - learning_rate * grad / (np.sqrt(moment_out) + - epsilon) + ms_out = decay * mean_square + (1 - decay) * grad * grad + moment_out = momentum * moment + \ + learning_rate * grad / np.sqrt(ms_out + epsilon) + param_out = param - moment_out - self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out} + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'MeanSquareOut': ms_out + } def test_check_output(self): self.check_output() From fa12e51675dbbc77eef75e5c346f2deecd45b0dc Mon Sep 17 00:00:00 2001 From: Kavya Srinet Date: Wed, 4 Oct 2017 13:27:40 -0700 Subject: [PATCH 4/5] Adding the default attribute test case --- paddle/operators/rmsprop_op.cc | 6 +-- paddle/operators/rmsprop_op.h | 6 +-- .../v2/framework/tests/test_rmsprop_op.py | 45 ++++++++++++++++++- 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc index 1e06e08ede..8f61c7fdda 100644 --- a/paddle/operators/rmsprop_op.cc +++ b/paddle/operators/rmsprop_op.cc @@ -89,13 +89,13 @@ class RmspropOpMaker : public 
framework::OpProtoAndCheckerMaker { AddAttr("epsilon", "(float, default 1e-10) Constant " "for numerical stability.") - .SetDefault(1e-10); + .SetDefault(1.0e-10f); AddAttr("decay", "(float, default 0.9) " "Discounting factor for coming gradient.") - .SetDefault(0.9); + .SetDefault(0.9f); AddAttr("momentum", "(float, default 0.0) Constant value") - .SetDefault(0.0); + .SetDefault(0.0f); AddComment(R"DOC( RMSprop diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index ed4b283ce4..9c04276ec6 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -28,9 +28,9 @@ template class RmspropOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto param_out = ctx.Output("ParamOut"); - auto moment_out = ctx.Output("MomentOut"); - auto mean_square_out = ctx.Output("MeanSquareOut"); + auto* param_out = ctx.Output("ParamOut"); + auto* moment_out = ctx.Output("MomentOut"); + auto* mean_square_out = ctx.Output("MeanSquareOut"); param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py index 84bd815c8c..3e5ff733e9 100644 --- a/python/paddle/v2/framework/tests/test_rmsprop_op.py +++ b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -3,7 +3,10 @@ import numpy as np from op_test import OpTest -class TestRmspropOp(OpTest): +class TestRmspropOp1(OpTest): + ''' Test RMSProp with explicit inputs + ''' + def setUp(self): self.op_type = "rmsprop" @@ -42,5 +45,45 @@ class TestRmspropOp(OpTest): self.check_output() +class TestRmspropOp2(OpTest): + '''Test RMSProp with default values for attributes + ''' + + def setUp(self): + self.op_type = "rmsprop" + + param = np.random.random((123, 321)).astype("float32") + mean_square = np.random.random((123, 321)).astype("float32") + learning_rate = np.array([0.01]).astype("float32") + grad
= np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + + epsilon = 1.0e-10 + decay = 0.9 + momentum = 0.0 + + self.inputs = { + 'Param': param, + 'MeanSquare': mean_square, + 'LearningRate': learning_rate, + 'Grad': grad, + 'Moment': moment, + } + + ms_out = decay * mean_square + (1 - decay) * grad * grad + moment_out = momentum * moment + \ + learning_rate * grad / np.sqrt(ms_out + epsilon) + param_out = param - moment_out + + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'MeanSquareOut': ms_out + } + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main() From f52cdaa0cee682ddc3588286af42d960141596f0 Mon Sep 17 00:00:00 2001 From: Kavya Srinet Date: Thu, 5 Oct 2017 19:12:27 -0700 Subject: [PATCH 5/5] Updated RMSProp to have learning rate as an input and work with GPU --- paddle/operators/rmsprop_op.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index 9c04276ec6..7bf2129010 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -32,6 +32,8 @@ class RmspropOpKernel : public framework::OpKernel { auto* moment_out = ctx.Output("MomentOut"); auto* mean_square_out = ctx.Output("MeanSquareOut"); + auto grad = ctx.Input("Grad"); + param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); mean_square_out->mutable_data(ctx.GetPlace()); @@ -42,8 +44,8 @@ class RmspropOpKernel : public framework::OpKernel { auto p = EigenVector::Flatten(*ctx.Input("Param")); auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); - float lr = ctx.Input("LearningRate")->data()[0]; - auto g = EigenVector::Flatten(*ctx.Input("Grad")); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + auto g = EigenVector::Flatten(*grad); auto mom = EigenVector::Flatten(*ctx.Input("Moment")); auto p_out = EigenVector::Flatten(*param_out); 
@@ -51,8 +53,12 @@ class RmspropOpKernel : public framework::OpKernel { auto ms_out = EigenVector::Flatten(*mean_square_out); auto place = ctx.GetEigenDevice(); + Eigen::DSizes grad_dsize(grad->numel()); + ms_out.device(place) = rho * ms + (1 - rho) * g * g; - mom_out.device(place) = momentum * mom + lr * g / (ms_out + epsilon).sqrt(); + mom_out.device(place) = + momentum * mom + + lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt(); p_out.device(place) = p - mom_out; } };