From 182ce51c6d73d98420aa91d998a328503eac538d Mon Sep 17 00:00:00 2001
From: qijun
Date: Tue, 17 Oct 2017 14:48:40 -0700
Subject: [PATCH 1/3] add sparse kernel of sgd operator

---
 paddle/operators/sgd_op.cc | 40 ++++++++++++++++++++++---
 paddle/operators/sgd_op.cu | 60 ++++++++++++++++++++++++++++++++++++++
 paddle/operators/sgd_op.h  | 47 ++++++++++++++++++++---------
 3 files changed, 130 insertions(+), 17 deletions(-)

diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc
index 0f78eeab9b..e26a1c7893 100644
--- a/paddle/operators/sgd_op.cc
+++ b/paddle/operators/sgd_op.cc
@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Param"),
                    "Input(Param) of SGDOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Grad"),
@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                       "Learning rate should have 1 element");
     auto param_dim = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
-                      "Two input of SGD Op's dimension must be same.");
+    // TODO(qijun): check dimensions of Param and Grad at compile
+    // and run time.
     ctx->SetOutputDim("ParamOut", param_dim);
   }
 };
 
 class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "Input parameter");
     AddInput("LearningRate", "Learning rate of SGD");
@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad;
 )DOC");
   }
 };
+
+template <typename T>
+struct SparseSGDFunctor<platform::CPUPlace, T> {
+  void operator()(const platform::DeviceContext& ctx,
+                  const framework::SelectedRows& input,
+                  const framework::Tensor& learning_rate,
+                  framework::Tensor* output) {
+    auto in_height = input.height();
+    auto out_dims = output->dims();
+    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
+
+    auto& in_value = input.value();
+    auto& in_rows = input.rows();
+
+    int64_t in_row_numel = in_value.numel() / in_rows.size();
+    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
+
+    auto* in_data = in_value.data<T>();
+    auto* out_data = output->data<T>();
+    auto* lr = learning_rate.data<T>();
+
+    for (size_t i = 0; i < in_rows.size(); i++) {
+      for (int64_t j = 0; j < in_row_numel; j++) {
+        out_data[in_rows[i] * in_row_numel + j] -=
+            lr[0] * in_data[i * in_row_numel + j];
+      }
+    }
+  }
+};
+
+template struct SparseSGDFunctor<platform::CPUPlace, float>;
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu
index f5ba6d3c29..5c28314141 100644
--- a/paddle/operators/sgd_op.cu
+++ b/paddle/operators/sgd_op.cu
@@ -14,6 +14,66 @@
 #define EIGEN_USE_GPU
 #include "paddle/operators/sgd_op.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+
+namespace {
+template <typename T>
+__global__ void SparseSGDFunctorKernel(const T* selected_rows,
+                                       const int64_t* rows,
+                                       const T* learning_rate, T* tensor_out,
+                                       int64_t row_numel, int block_size) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+
+  selected_rows += ty * row_numel;
+  tensor_out += rows[ty] * row_numel;
+
+  for (int index = tid; index < row_numel; index += block_size) {
+    // Since an index in rows of SelectedRows can be duplicated, we have to
+    // use atomic operations to avoid concurrent write errors.
+    paddle::platform::CudaAtomicSub(tensor_out + index,
+                                    learning_rate[0] * selected_rows[index]);
+  }
+}
+}  // namespace
+
+template <typename T>
+struct SparseSGDFunctor<platform::GPUPlace, T> {
+  void operator()(const platform::DeviceContext& ctx,
+                  const framework::SelectedRows& input,
+                  const framework::Tensor& learning_rate,
+                  framework::Tensor* output) {
+    auto in_height = input.height();
+    auto out_dims = output->dims();
+    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
+
+    auto& in_value = input.value();
+    auto& in_rows = input.rows();
+
+    int64_t in_row_numel = in_value.numel() / in_rows.size();
+    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
+
+    auto* in_data = in_value.data<T>();
+    auto* out_data = output->data<T>();
+
+    int block_size = 256;
+    dim3 threads(block_size, 1);
+    dim3 grid(1, in_rows.size());
+    SparseSGDFunctorKernel<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(in_data, in_rows.data(), learning_rate.data<T>(),
+                              out_data, in_row_numel, block_size);
+  }
+};
+
+template struct SparseSGDFunctor<platform::GPUPlace, float>;
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(sgd,
diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index 26f4012f25..a872d7f749 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -15,31 +15,52 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/selected_rows.h"
 
 namespace paddle {
 namespace operators {
 
+template <typename Place, typename T>
+struct SparseSGDFunctor {
+  void operator()(const platform::DeviceContext& ctx,
+                  const framework::SelectedRows& input,
+                  const framework::Tensor& learning_rate,
+                  framework::Tensor* output);
+};
+
 template <typename Place, typename T>
 class SGDOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param = ctx.Input<framework::Tensor>("Param");
-    auto grad = ctx.Input<framework::Tensor>("Grad");
-    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
-    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+    auto* param = ctx.Input<framework::Tensor>("Param");
+    auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
 
-    param_out->mutable_data<T>(ctx.GetPlace());
+    auto* grad_var = ctx.InputVar("Grad");
+    if (grad_var->IsType<framework::Tensor>()) {
+      param_out->mutable_data<T>(ctx.GetPlace());
+      auto* grad = ctx.Input<framework::Tensor>("Grad");
 
-    auto p = framework::EigenVector<T>::Flatten(*param);
-    auto g = framework::EigenVector<T>::Flatten(*grad);
-    auto o = framework::EigenVector<T>::Flatten(*param_out);
-    auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
-    auto place = ctx.GetEigenDevice<Place>();
+      auto p = framework::EigenVector<T>::Flatten(*param);
+      auto g = framework::EigenVector<T>::Flatten(*grad);
+      auto o = framework::EigenVector<T>::Flatten(*param_out);
+      auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
+      auto place = ctx.GetEigenDevice<Place>();
 
-    Eigen::DSizes<int, 1> grad_dsize(grad->numel());
-    o.device(place) = p - lr.broadcast(grad_dsize) * g;
+      Eigen::DSizes<int, 1> grad_dsize(grad->numel());
+      o.device(place) = p - lr.broadcast(grad_dsize) * g;
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
+      // This manual optimization brings difficulty to track data dependency.
+      // It's better to find a more elegant solution.
+      PADDLE_ENFORCE_EQ(param, param_out);
+      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
+      SparseSGDFunctor<Place, T> functor;
+      functor(ctx.device_context(), *grad, *learning_rate, param_out);
+    } else {
+      PADDLE_THROW("Unsupported Variable Type of Grad");
+    }
   }
 };
-
 }  // namespace operators
 }  // namespace paddle
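Note on PATCH 1/3: both the CPU functor and the CUDA kernel implement the same
scatter-style update, param[rows[i]] -= lr * grad_value[i], accumulating when a
row index appears more than once (the CPU loop via -=, the GPU kernel via an
atomic). A minimal NumPy sketch of that semantics, useful as a reference when
reading the unit test in PATCH 2/3 (sparse_sgd_reference is a hypothetical
helper, not part of the patch):

    import numpy as np

    def sparse_sgd_reference(param, rows, grad_value, lr):
        # Scatter-update: param[rows[i]] -= lr * grad_value[i].
        # np.add.at accumulates over duplicate row indices, matching the
        # -= loop in the CPU functor and the atomic update on the GPU.
        out = param.copy()
        np.add.at(out, rows, -lr * grad_value)
        return out

    # Mirrors the data used in the unit test: height 10, rows [0, 4, 7].
    param = np.full((10, 12), 5.0, dtype=np.float32)
    grad = np.ones((3, 12), dtype=np.float32)
    grad[0, 0] = 2.0
    grad[2, 8] = 4.0
    out = sparse_sgd_reference(param, [0, 4, 7], grad, lr=2.0)
    assert out[0, 0] == 1.0 and out[7, 8] == -3.0 and out[1, 0] == 5.0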
From ab8cc401e61dd49d393a72903a427ea6fa14bec7 Mon Sep 17 00:00:00 2001
From: qijun
Date: Tue, 17 Oct 2017 16:05:05 -0700
Subject: [PATCH 2/3] add sparse sgd operator unittest

---
 paddle/operators/sgd_op.h                      |  3 +-
 paddle/pybind/pybind.cc                        |  5 ++
 .../v2/framework/tests/test_selected_rows.py   | 23 +++----
 .../paddle/v2/framework/tests/test_sgd_op.py   | 60 +++++++++++++++++++
 4 files changed, 79 insertions(+), 12 deletions(-)

diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index a872d7f749..8c28d5e66b 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -37,7 +37,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
     auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
 
     auto* grad_var = ctx.InputVar("Grad");
-    if (grad_var->IsType<framework::Tensor>()) {
+    // Actually, all tensors are LoDTensor except SelectedRows.
+    if (grad_var->IsType<framework::LoDTensor>()) {
       param_out->mutable_data<T>(ctx.GetPlace());
       auto* grad = ctx.Input<framework::Tensor>("Grad");
 
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index fcae92ad99..65e265b614 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -186,6 +186,11 @@ All parameter, weight, gradient are variables in Paddle.
             return self.GetMutable<LoDTensor>();
           },
           py::return_value_policy::reference)
+      .def("get_selected_rows",
+           [](Variable &self) -> SelectedRows * {
+             return self.GetMutable<SelectedRows>();
+           },
+           py::return_value_policy::reference)
       .def("get_net",
            [](Variable &self) -> operators::NetOp * {
              return self.GetMutable<operators::NetOp>();
diff --git a/python/paddle/v2/framework/tests/test_selected_rows.py b/python/paddle/v2/framework/tests/test_selected_rows.py
index 661e818179..e8a930cb08 100644
--- a/python/paddle/v2/framework/tests/test_selected_rows.py
+++ b/python/paddle/v2/framework/tests/test_selected_rows.py
@@ -8,29 +8,30 @@ class TestSelectedRows(unittest.TestCase):
         place = core.CPUPlace()
         height = 10
         rows = [0, 4, 7]
-        row_numel = 10
-        selcted_rows = core.SelectedRows(rows, row_numel)
-        np_array = np.ones((len(rows), height)).astype("float32")
+        row_numel = 12
+        selected_rows = core.SelectedRows(rows, height)
+        np_array = np.ones((len(rows), row_numel)).astype("float32")
         np_array[0, 0] = 2.0
         np_array[2, 8] = 4.0
-        tensor = selcted_rows.get_tensor()
+        tensor = selected_rows.get_tensor()
         tensor.set(np_array, place)
 
         # compare rows
-        self.assertEqual(0, selcted_rows.rows()[0])
-        self.assertEqual(4, selcted_rows.rows()[1])
-        self.assertEqual(7, selcted_rows.rows()[2])
+        self.assertEqual(0, selected_rows.rows()[0])
+        self.assertEqual(4, selected_rows.rows()[1])
+        self.assertEqual(7, selected_rows.rows()[2])
 
         # compare height
-        self.assertEqual(10, selcted_rows.height())
+        self.assertEqual(10, selected_rows.height())
 
         # compare tensor
         self.assertAlmostEqual(2.0,
-                               selcted_rows.get_tensor().get_float_element(0))
+                               selected_rows.get_tensor().get_float_element(0))
         self.assertAlmostEqual(1.0,
-                               selcted_rows.get_tensor().get_float_element(1))
+                               selected_rows.get_tensor().get_float_element(1))
         self.assertAlmostEqual(
-            4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8))
+            4.0,
+            selected_rows.get_tensor().get_float_element(2 * row_numel + 8))
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
index 2dd881e5e1..c7d6a3b345 100644
--- a/python/paddle/v2/framework/tests/test_sgd_op.py
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -1,5 +1,7 @@
 import unittest
 import numpy as np
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator
 from op_test import OpTest
 
 
@@ -17,5 +19,63 @@ class TestSGDOp(OpTest):
         self.check_output()
 
 
+class TestSparseSGDOp(unittest.TestCase):
+    def test_sparse_sgd(self):
+        scope = core.Scope()
+
+        # create and initialize Grad Variable
+        place = core.CPUPlace()
+        height = 10
+        rows = [0, 4, 7]
+        row_numel = 12
+
+        grad_selected_rows = scope.var('Grad').get_selected_rows()
+        grad_selected_rows.set_height(height)
+        grad_selected_rows.set_rows(rows)
+        np_array = np.ones((len(rows), row_numel)).astype("float32")
+        np_array[0, 0] = 2.0
+        np_array[2, 8] = 4.0
+        grad_tensor = grad_selected_rows.get_tensor()
+        grad_tensor.set(np_array, place)
+
+        # create and initialize Param Variable
+        param = scope.var('Param').get_tensor()
+        param_array = np.full((height, row_numel), 5.0).astype("float32")
+        param.set(param_array, place)
+
+        # create and initialize LearningRate Variable
+        lr = scope.var('LearningRate').get_tensor()
+        lr_array = np.full((1), 2.0).astype("float32")
+        lr.set(lr_array, place)
+
+        # create and run sgd operator
+        sgd_op = Operator(
+            "sgd",
+            Param='Param',
+            Grad='Grad',
+            ParamOut='Param',
+            LearningRate='LearningRate')
+        ctx = core.DeviceContext.create(place)
+        sgd_op.run(scope, ctx)
+
+        # get and compare result
+        result_array = np.array(param)
+
+        # rows[0] = 0, 5.0 - 2.0 * 2.0
+        self.assertAlmostEqual(1.0, result_array[rows[0], 0])
+        # rows[0] = 0, 5.0 - 2.0 * 1.0
+        self.assertAlmostEqual(3.0, result_array[rows[0], 2])
+        # 5.0 - 2.0 * 0.0
+        self.assertAlmostEqual(5.0, result_array[1, 0])
+        # rows[1] = 4, 5.0 - 2.0 * 1.0
+        self.assertAlmostEqual(3.0, result_array[rows[1], 10])
+        # 5.0 - 2.0 * 0.0
+        self.assertAlmostEqual(5.0, result_array[5, 8])
+        # rows[2] = 7, 5.0 - 2.0 * 1.0
+        self.assertAlmostEqual(3.0, result_array[rows[2], 1])
+        # rows[2] = 7, 5.0 - 2.0 * 4.0
+        self.assertAlmostEqual(-3.0, result_array[rows[2], 8])
+
+
 if __name__ == "__main__":
     unittest.main()
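Note on PATCH 2/3: a SelectedRows with height 10, rows [0, 4, 7], and a
(3 x 12) value tensor is the sparse encoding of a (10 x 12) dense gradient
that is zero outside the selected rows. A sketch of the expansion
(selected_rows_to_dense is a hypothetical helper, not a Paddle API):

    import numpy as np

    def selected_rows_to_dense(height, rows, value):
        # Dense equivalent of the SelectedRows (height, rows, value) triple.
        dense = np.zeros((height,) + value.shape[1:], dtype=value.dtype)
        np.add.at(dense, rows, value)  # accumulate in case a row repeats
        return dense

    dense_grad = selected_rows_to_dense(
        10, [0, 4, 7], np.ones((3, 12), dtype=np.float32))
    assert dense_grad.sum() == 36.0  # only the three selected rows are set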
From f9681459b2075e8067e6bda45a62967fc4baec62 Mon Sep 17 00:00:00 2001
From: qijun
Date: Tue, 17 Oct 2017 16:33:52 -0700
Subject: [PATCH 3/3] fix gpu build error

---
 paddle/operators/sgd_op.cc                      |  2 +-
 paddle/operators/sgd_op.cu                      |  6 +++---
 paddle/operators/sgd_op.h                       |  2 +-
 paddle/pybind/pybind.cc                         | 10 +++++++++-
 python/paddle/v2/framework/tests/test_sgd_op.py | 11 +++++++++--
 5 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc
index e26a1c7893..2acb96d1b4 100644
--- a/paddle/operators/sgd_op.cc
+++ b/paddle/operators/sgd_op.cc
@@ -61,7 +61,7 @@ param_out = param - learning_rate * grad;
 
 template <typename T>
 struct SparseSGDFunctor<platform::CPUPlace, T> {
-  void operator()(const platform::DeviceContext& ctx,
+  void operator()(const platform::DeviceContext& context,
                   const framework::SelectedRows& input,
                   const framework::Tensor& learning_rate,
                   framework::Tensor* output) {
diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu
index 5c28314141..106f9b746b 100644
--- a/paddle/operators/sgd_op.cu
+++ b/paddle/operators/sgd_op.cu
@@ -34,15 +34,15 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
 
   for (int index = tid; index < row_numel; index += block_size) {
     // Since an index in rows of SelectedRows can be duplicated, we have to
     // use atomic operations to avoid concurrent write errors.
-    paddle::platform::CudaAtomicSub(tensor_out + index,
-                                    learning_rate[0] * selected_rows[index]);
+    paddle::platform::CudaAtomicAdd(
+        tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
   }
 }
 }  // namespace
 
 template <typename T>
 struct SparseSGDFunctor<platform::GPUPlace, T> {
-  void operator()(const platform::DeviceContext& ctx,
+  void operator()(const platform::DeviceContext& context,
                   const framework::SelectedRows& input,
                   const framework::Tensor& learning_rate,
                   framework::Tensor* output) {
diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index 8c28d5e66b..78b595fc6c 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -22,7 +22,7 @@ namespace operators {
 
 template <typename Place, typename T>
 struct SparseSGDFunctor {
-  void operator()(const platform::DeviceContext& ctx,
+  void operator()(const platform::DeviceContext& context,
                   const framework::SelectedRows& input,
                   const framework::Tensor& learning_rate,
                   framework::Tensor* output);
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 65e265b614..80854fb0c5 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -153,7 +153,15 @@ PYBIND11_PLUGIN(core) {
            py::return_value_policy::reference)
       .def("set_height", &SelectedRows::set_height)
       .def("height", &SelectedRows::height)
-      .def("set_rows", &SelectedRows::set_rows)
+      .def("set_rows",
+           [](SelectedRows &self, std::vector<int64_t> rows) {
+#ifndef PADDLE_WITH_CUDA
+             self.set_rows(rows);
+#else
+             Vector<int64_t> new_rows(rows);
+             self.set_rows(new_rows);
+#endif
+           })
       .def("rows", [](SelectedRows &self) {
 #ifndef PADDLE_WITH_CUDA
         return self.rows();
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
index c7d6a3b345..01262bba4d 100644
--- a/python/paddle/v2/framework/tests/test_sgd_op.py
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -20,11 +20,10 @@ class TestSGDOp(OpTest):
 
 
 class TestSparseSGDOp(unittest.TestCase):
-    def test_sparse_sgd(self):
+    def check_with_place(self, place):
         scope = core.Scope()
 
         # create and initialize Grad Variable
-        place = core.CPUPlace()
         height = 10
         rows = [0, 4, 7]
         row_numel = 12
@@ -35,6 +34,7 @@ class TestSparseSGDOp(unittest.TestCase):
         np_array = np.ones((len(rows), row_numel)).astype("float32")
         np_array[0, 0] = 2.0
         np_array[2, 8] = 4.0
+
         grad_tensor = grad_selected_rows.get_tensor()
         grad_tensor.set(np_array, place)
 
@@ -76,6 +76,13 @@ class TestSparseSGDOp(unittest.TestCase):
         # rows[2] = 7, 5.0 - 2.0 * 4.0
         self.assertAlmostEqual(-3.0, result_array[rows[2], 8])
 
+    def test_sparse_sgd(self):
+        places = [core.CPUPlace()]
+        if core.is_compile_gpu():
+            places.append(core.GPUPlace(0))
+        for place in places:
+            self.check_with_place(place)
+
 
 if __name__ == "__main__":
     unittest.main()
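Note on PATCH 3/3: besides the ctx/context rename, the kernel switches from
CudaAtomicSub to CudaAtomicAdd of a negated value. The atomic matters because
rows may contain duplicates: a plain scatter write keeps only one contribution
per destination row, while the SGD semantics require accumulation. A small
NumPy demonstration of the difference (illustrative only, not part of the
patch):

    import numpy as np

    rows = [1, 1, 2]                 # row 1 appears twice
    vals = np.array([1.0, 2.0, 3.0])

    out_plain = np.zeros(3)
    out_plain[rows] -= vals          # last write wins: row 1 only loses 2.0

    out_accum = np.zeros(3)
    np.subtract.at(out_accum, rows, vals)  # accumulates: row 1 loses 3.0

    print(out_plain)   # [ 0. -2. -3.]
    print(out_accum)   # [ 0. -3. -3.]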