commit e3b27d1998 (parent d81084939b)
@@ -0,0 +1,61 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {

class SGDOp : public framework::OperatorWithKernel {
protected:
  void InferShape(
      const std::vector<const framework::Tensor *> &inputs,
      const std::vector<framework::Tensor *> &outputs) const override {
    PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
    PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
    PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] must be set");
    PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] must be set");
    PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] must be set");
    PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
                   "The two inputs of SGDOp must have the same dimension.");
    outputs[0]->set_dims(inputs[0]->dims());
  }
};

class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
  SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("param", "input parameter");
    AddInput("grad", "input gradient");
    AddOutput("param_out", "output parameter");
    AddAttr<float>("learning_rate", "learning rate of sgd");
    AddComment(R"DOC(

Simplest SGD algorithm.

param_out = param - learning_rate * grad;

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker);
// The typedef keeps the comma inside the template argument list from
// being parsed as a macro argument separator below.
typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float>
    SGDOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
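The DOC block above fully specifies the operator: one step of vanilla SGD. For reference, a minimal NumPy sketch of the same computation (the function name and shapes here are illustrative only, not part of the patch):

import numpy as np

def sgd_update(param, grad, learning_rate):
    # Same constraint SGDOp::InferShape enforces: matching input
    # shapes, and the output keeps that shape.
    assert param.shape == grad.shape
    return param - learning_rate * grad

param = np.random.random((342, 345)).astype("float32")
grad = np.random.random((342, 345)).astype("float32")
param_out = sgd_update(param, grad, 0.1)
assert param_out.shape == param.shape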
@@ -0,0 +1,5 @@
#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"

typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float);
@@ -0,0 +1,39 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel {
public:
  void Compute(const framework::KernelContext& ctx) const override {
    // Inputs are looked up by the names declared in SGDOpMaker.
    auto param = ctx.Input("param")->Get<framework::Tensor>();
    auto grad = ctx.Input("grad")->Get<framework::Tensor>();
    auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>();
    float lr = ctx.op_.GetAttr<float>("learning_rate");

    // Allocate the output buffer on the place this kernel runs on.
    param_out->mutable_data<T>(ctx.GetPlace());

    // Elementwise update, evaluated by Eigen on the device selected via
    // the Place template parameter: param_out = param - lr * grad.
    param_out->flat<T>().device(*(ctx.GetEigenDevice<Place>())) =
        param.flat<T>() - lr * grad.flat<T>();
  }
};

}  // namespace operators
}  // namespace paddle
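As a quick sanity check on the update this kernel evaluates, a short, self-contained NumPy sketch (illustrative only) iterates the rule on a toy quadratic; the gradient of 0.5 * ||x||^2 is x itself, so repeated updates shrink x toward zero:

import numpy as np

x = np.array([3.0, -2.0], dtype="float32")
lr = 0.1
for _ in range(100):
    grad = x               # gradient of 0.5 * ||x||^2
    x = x - lr * grad      # the same update SGDOpKernel::Compute performs
print(x)                   # close to [0, 0] after 100 steps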
@@ -0,0 +1,22 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include <paddle/framework/op_registry.h>
// USE_OP pulls in the sgd operator registration so it links into the test.
USE_OP(sgd);
TEST(SGDOp, GetOpProto) {
  auto& protos = paddle::framework::OpRegistry::protos();
  auto it = protos.find("sgd");
  ASSERT_NE(it, protos.end());
}
@@ -1,2 +1,2 @@
 cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
-        add_op fc_op)
+        add_op fc_op sgd_op)
@@ -1,3 +1,3 @@
 add_python_test(test_framework test_protobuf.py test_scope.py
     test_default_scope_funcs.py test_op_creation_methods.py
-    test_tensor.py test_fc_op.py test_add_two_op.py)
+    test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py)
@@ -0,0 +1,18 @@
import unittest
import numpy
from op_test_util import OpTestMeta


class TestSGD(unittest.TestCase):
    # OpTestMeta generates the actual test body from the attributes
    # prepared in setUp (op type, inputs, attrs, expected outputs).
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "sgd"
        self.param = numpy.random.random((342, 345)).astype("float32")
        self.grad = numpy.random.random((342, 345)).astype("float32")
        self.learning_rate = 0.1
        self.param_out = self.param - self.learning_rate * self.grad


if __name__ == "__main__":
    unittest.main()