From b3f44ad761ae5e3d9afa5bbedecdb10d3926c8cd Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 13 Sep 2017 17:16:33 +0800
Subject: [PATCH 1/4] add multiplex operator

---
 paddle/operators/multiplex_op.cc              | 107 ++++++++++++++++++
 paddle/operators/multiplex_op.cu              |  76 +++++++++++++
 paddle/operators/multiplex_op.h               |  68 +++++++++++
 paddle/pybind/pybind.cc                       |   1 +
 .../paddle/v2/framework/tests/CMakeLists.txt  |   1 +
 .../v2/framework/tests/test_multiplex_op.py   |  34 ++++++
 6 files changed, 287 insertions(+)
 create mode 100644 paddle/operators/multiplex_op.cc
 create mode 100644 paddle/operators/multiplex_op.cu
 create mode 100644 paddle/operators/multiplex_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_multiplex_op.py

diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
new file mode 100644
index 0000000000..67e8e5f5d7
--- /dev/null
+++ b/paddle/operators/multiplex_op.cc
@@ -0,0 +1,107 @@
+
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/multiplex_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class MultiplexOp : public framework::OperatorWithKernel {
+ public:
+  MultiplexOp(const std::string &type, const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto *out = ctx.Output<Tensor>("Out");
+    auto num_ins = ins.size();
+    PADDLE_ENFORCE(num_ins > 2,
+                   "multiplex operator should have more than 2 inputs.");
+    PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
+                      "The first input must be a index vector.");
+    auto in_dim = ins[1]->dims();
+
+    for (size_t i = 2; i < num_ins; i++) {
+      auto dim = ins[i]->dims();
+      PADDLE_ENFORCE(
+          in_dim == dim,
+          "All the input tensors except the first one must have the same size");
+    }
+    out->Resize(in_dim);
+  }
+};
+
+class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  MultiplexOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input tensor of multiplex operator.").AsDuplicable();
+    AddOutput("Out", "The output tensor of multiplex operator.");
+    AddComment(R"DOC(Multiplex operator
+
+Multiplex multiple tensors according to the index provided by the first
+input tensor.
+
+ins[0]: the index of the tensor to output of size batchSize.
+ins[1:N]: the candidate output tensor.
+For each index i from 0 to batchSize - 1, the output is the i-th row of the
+the (index[i] + 1)-th tensor.
+
+For each i-th row of output:
+
+y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
+
+where y is the output tensor. `x_{k}` is the k-th input layer
+and `k = x{0}[i] + 1`.
+
+)DOC");
+  }
+};
+
+class MultiplexGradOp : public framework::OperatorWithKernel {
+ public:
+  MultiplexGradOp(const std::string &type,
+                  const framework::VariableNameMap &inputs,
+                  const framework::VariableNameMap &outputs,
+                  const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) shouldn't be null.");
+    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    auto ins = ctx.MultiInput<Tensor>("X");
+    for (size_t i = 0; i < ins.size(); i++) {
+      auto dims = ins[i]->dims();
+      d_ins[i]->Resize(dims);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+namespace ops = paddle::operators;
+
+REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
+            ops::MultiplexGradOp);
+REGISTER_OP_CPU_KERNEL(multiplex, ops::MultiplexCPUKernel<float>);
+REGISTER_OP_CPU_KERNEL(multiplex_grad, ops::MultiplexGradCPUKernel<float>);
diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu
new file mode 100644
index 0000000000..81d637686b
--- /dev/null
+++ b/paddle/operators/multiplex_op.cu
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class MultiplexGPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    // copy index to cpu
+    Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<T>(*(ins[0]), paddle::platform::CPUPlace());
+    auto index = index_t_cpu.data<T>();
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      cudaMemcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
+                 cols * sizeof(T), cudaMemcpyDeviceToDevice);
+    }
+  }
+};
+
+template <typename T>
+class MultiplexGradGPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    for (auto d_in : d_ins) {
+      d_in->mutable_data<T>(ctx.GetPlace());
+      auto dims = d_in->dims();
+      cudaMemset(d_in->data<T>(), 0, framework::product(dims) * sizeof(T));
+    }
+
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    // copy index to cpu
+    Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<T>(*(ins[0]), paddle::platform::CPUPlace());
+    auto index = index_t_cpu.data<T>();
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      cudaMemcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
+                 cols * sizeof(T), cudaMemcpyDeviceToDevice);
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_GPU_KERNEL(multiplex, ops::MultiplexGPUKernel<float>);
+REGISTER_OP_GPU_KERNEL(multiplex_grad, ops::MultiplexGradGPUKernel<float>);
diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h
new file mode 100644
index 0000000000..7b627a83b3
--- /dev/null
+++ b/paddle/operators/multiplex_op.h
@@ -0,0 +1,68 @@
+
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class MultiplexCPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto index = ins[0]->data<T>();
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      memcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
+             cols * sizeof(T));
+    }
+  }
+};
+
+template <typename T>
+class MultiplexGradCPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto d_ins =
+        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
+    for (auto d_in : d_ins) {
+      d_in->mutable_data<T>(ctx.GetPlace());
+      auto dims = d_in->dims();
+      memset(d_in->data<T>(), 0, framework::product(dims) * sizeof(T));
+    }
+
+    auto index = ins[0]->data<T>();
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      memcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
+             cols * sizeof(T));
+    }
+  }
+};
+}
+}
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 16a2368aae..f0ac1f7f38 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -55,6 +55,7 @@ USE_OP(top_k);
 USE_OP(squared_l2_distance);
 USE_OP(sum);
 USE_OP(reshape);
+USE_OP(multiplex);
 
 namespace paddle {
 namespace framework {
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 6b22c00082..752c5a5265 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -36,3 +36,4 @@ py_test(mnist SRCS mnist.py)
 py_test(test_concat_op SRCS test_concat_op.py)
 py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
 py_test(test_reshape_op SRCS test_reshape_op.py)
+py_test(test_multiplex_op SRCS test_multiplex_op.py)
diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/framework/tests/test_multiplex_op.py
new file mode 100644
index 0000000000..c42cb6f0fe
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_multiplex_op.py
@@ -0,0 +1,34 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestMultiplexOp(OpTest):
+    def setUp(self):
+        self.op_type = "multiplex"
+        rows = 3
+        index = np.array([3, 1, 0])
+        ins1 = np.random.random((rows, 10)).astype("float32")
+        ins2 = np.random.random((rows, 10)).astype("float32")
+        ins3 = np.random.random((rows, 10)).astype("float32")
+        ins4 = np.random.random((rows, 10)).astype("float32")
+        self.inputs = {
+            'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3),
+                  ('x4', ins4)]
+        }
+        # multiplex output
+        output = np.zeros_like(ins1)
+        for i in range(0, rows):
+            k = index[i] + 1
+            output[i] = self.inputs['X'][k][1][i]
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["x1"], "Out")
+
+
+if __name__ == '__main__':
+    unittest.main()

From 9da5192f771e43ec6e6f4cdaec2ba9ecd28337f1 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 20 Sep 2017 22:22:34 +0800
Subject: [PATCH 2/4] adapt multiplex_op to the dev of framework

---
 paddle/operators/multiplex_op.cc              | 26 +++++++++++++------
 paddle/operators/multiplex_op.cu              | 20 +++++++++-----
 paddle/operators/multiplex_op.h               | 18 ++++++++-----
 .../v2/framework/tests/test_multiplex_op.py   | 11 +++++++-
 4 files changed, 52 insertions(+), 23 deletions(-)

diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
index 67e8e5f5d7..03559d0643 100644
--- a/paddle/operators/multiplex_op.cc
+++ b/paddle/operators/multiplex_op.cc
@@ -1,4 +1,3 @@
-
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
    Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,6 +18,7 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
 
 class MultiplexOp : public framework::OperatorWithKernel {
  public:
@@ -29,8 +29,12 @@ class MultiplexOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
+                   "Input(X) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
+                            "Output(Out) shouldn't be null.");
     auto ins = ctx.MultiInput<Tensor>("X");
-    auto *out = ctx.Output<Tensor>("Out");
+    auto *out = ctx.Output<LoDTensor>("Out");
     auto num_ins = ins.size();
     PADDLE_ENFORCE(num_ins > 2,
                    "multiplex operator should have more than 2 inputs.");
@@ -53,7 +57,7 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
   MultiplexOpMaker(framework::OpProto *proto,
                    framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor of multiplex operator.").AsDuplicable();
+    AddInput("X", "The input tensors of multiplex operator.").AsDuplicable();
     AddOutput("Out", "The output tensor of multiplex operator.");
     AddComment(R"DOC(Multiplex operator
 
@@ -69,7 +73,7 @@ For each i-th row of output:
 
 y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
 
-where y is the output tensor. `x_{k}` is the k-th input layer
+where y is the output tensor. `x_{k}` is the k-th input tensor
 and `k = x{0}[i] + 1`.
)DOC"); @@ -86,13 +90,19 @@ class MultiplexGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) should not be null"); + PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), + "Output(X@Grad) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) shouldn't be null."); - auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); + auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); auto ins = ctx.MultiInput("X"); - for (size_t i = 0; i < ins.size(); i++) { - auto dims = ins[i]->dims(); - d_ins[i]->Resize(dims); + // don;t compute gradient for index + for (size_t i = 1; i < ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->Resize(ins[i]->dims()); + } } } }; diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index 81d637686b..055e13d183 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -18,13 +18,14 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; template class MultiplexGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); auto rows = ins[1]->dims()[0]; @@ -48,10 +49,13 @@ class MultiplexGradGPUKernel : public framework::OpKernel { auto* d_out = ctx.Input(framework::GradVarName("Out")); auto ins = ctx.MultiInput("X"); auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - for (auto d_in : d_ins) { - d_in->mutable_data(ctx.GetPlace()); - auto dims = d_in->dims(); - cudaMemset(d_in->data(), 0, framework::product(dims) * sizeof(T)); + for (size_t i = 1; i < d_ins.size(); ++i) { + if (d_ins[i]) { + d_ins[i]->mutable_data(ctx.GetPlace()); + auto dims = d_ins[i]->dims(); + cudaMemset(d_ins[i]->data(), 0, + framework::product(dims) * sizeof(T)); + } } auto rows = ins[1]->dims()[0]; @@ -62,8 +66,10 @@ class MultiplexGradGPUKernel : public framework::OpKernel { auto index = index_t_cpu.data(); for (auto i = 0; i < rows; i++) { int k = (int)index[i] + 1; - cudaMemcpy(d_ins[k]->data() + i * cols, d_out->data() + i * cols, - cols * sizeof(T), cudaMemcpyDeviceToDevice); + if (d_ins[k]) { + cudaMemcpy(d_ins[k]->data() + i * cols, d_out->data() + i * cols, + cols * sizeof(T), cudaMemcpyDeviceToDevice); + } } } }; diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h index 7b627a83b3..82b4a2c4c7 100644 --- a/paddle/operators/multiplex_op.h +++ b/paddle/operators/multiplex_op.h @@ -26,7 +26,7 @@ class MultiplexCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); auto index = ins[0]->data(); @@ -48,10 +48,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel { auto ins = ctx.MultiInput("X"); auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - for (auto d_in : d_ins) { - d_in->mutable_data(ctx.GetPlace()); - auto dims = d_in->dims(); - memset(d_in->data(), 0, framework::product(dims) * sizeof(T)); + for (size_t i = 1; i < d_ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->mutable_data(ctx.GetPlace()); + auto dims = d_ins[i]->dims(); + 
+        memset(d_ins[i]->data<T>(), 0, framework::product(dims) * sizeof(T));
+      }
     }
 
     auto index = ins[0]->data<T>();
@@ -59,8 +61,10 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
     auto cols = ins[1]->dims()[1];
     for (auto i = 0; i < rows; i++) {
       int k = (int)index[i] + 1;
-      memcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
-             cols * sizeof(T));
+      if (d_ins[k]) {
+        memcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
+               cols * sizeof(T));
+      }
     }
   }
 };
diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/framework/tests/test_multiplex_op.py
index c42cb6f0fe..f2b3881cde 100644
--- a/python/paddle/v2/framework/tests/test_multiplex_op.py
+++ b/python/paddle/v2/framework/tests/test_multiplex_op.py
@@ -27,7 +27,16 @@ class TestMultiplexOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(["x1"], "Out")
+        self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out')
+
+    def test_check_grad_ignore_x1(self):
+        self.check_grad(['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1'))
+
+    def test_check_grad_ignore_x1_x2(self):
+        self.check_grad(['x3', 'x4'], 'Out', no_grad_set=set(['x1', 'x2']))
+
+    def test_check_grad_ignore_x3(self):
+        self.check_grad(['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3'))
 
 
 if __name__ == '__main__':

From 7620efdf1c1a6ec593c63ed017530e3fe8580f72 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Sat, 23 Sep 2017 13:55:47 +0800
Subject: [PATCH 3/4] combine gpu&cpu code in multiplex_op

---
 paddle/operators/multiplex_op.cc | 26 +++++-----
 paddle/operators/multiplex_op.cu | 70 +++------------------------
 paddle/operators/multiplex_op.h  | 81 +++++++++++++++++++++++++-------
 3 files changed, 81 insertions(+), 96 deletions(-)

diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
index 03559d0643..6b22c782fe 100644
--- a/paddle/operators/multiplex_op.cc
+++ b/paddle/operators/multiplex_op.cc
@@ -22,10 +22,7 @@ using LoDTensor = framework::LoDTensor;
 
 class MultiplexOp : public framework::OperatorWithKernel {
  public:
-  MultiplexOp(const std::string &type, const framework::VariableNameMap &inputs,
-              const framework::VariableNameMap &outputs,
-              const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+  using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
@@ -64,12 +61,12 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
 Multiplex multiple tensors according to the index provided by the first
 input tensor.
 
-ins[0]: the index of the tensor to output of size batchSize.
-ins[1:N]: the candidate output tensor.
+ins[0]: the index tensor.
+ins[1:N]: the candidate output tensors.
 For each index i from 0 to batchSize - 1, the output is the i-th row of the
 the (index[i] + 1)-th tensor.
 
-For each i-th row of output:
+For i-th row of the output tensor:
 
 y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
 
 where y is the output tensor. `x_{k}` is the k-th input tensor
 and `k = x{0}[i] + 1`.
 
 )DOC");
@@ -82,11 +79,7 @@ and `k = x{0}[i] + 1`.
 
 class MultiplexGradOp : public framework::OperatorWithKernel {
  public:
-  MultiplexGradOp(const std::string &type,
-                  const framework::VariableNameMap &inputs,
-                  const framework::VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+  using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
@@ -98,7 +91,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
                             "Input(Out@GRAD) shouldn't be null.");
     auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
     auto ins = ctx.MultiInput<Tensor>("X");
-    // don;t compute gradient for index
+    // don't compute gradient for index (ins[0])
     for (size_t i = 1; i < ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->Resize(ins[i]->dims());
@@ -113,5 +106,8 @@ namespace ops = paddle::operators;
 
 REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
             ops::MultiplexGradOp);
-REGISTER_OP_CPU_KERNEL(multiplex, ops::MultiplexCPUKernel<float>);
-REGISTER_OP_CPU_KERNEL(multiplex_grad, ops::MultiplexGradCPUKernel<float>);
+REGISTER_OP_CPU_KERNEL(multiplex,
+                       ops::MultiplexKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    multiplex_grad,
+    ops::MultiplexGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu
index 055e13d183..3d219389ba 100644
--- a/paddle/operators/multiplex_op.cu
+++ b/paddle/operators/multiplex_op.cu
@@ -13,70 +13,12 @@
    limitations under the License. */
 
 #include "paddle/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
-
-template <typename T>
-class MultiplexGPUKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto ins = ctx.MultiInput<Tensor>("X");
-    auto* out = ctx.Output<LoDTensor>("Out");
-    out->mutable_data<T>(ctx.GetPlace());
-
-    auto rows = ins[1]->dims()[0];
-    auto cols = ins[1]->dims()[1];
-    // copy index to cpu
-    Tensor index_t_cpu;
-    index_t_cpu.CopyFrom<T>(*(ins[0]), paddle::platform::CPUPlace());
-    auto index = index_t_cpu.data<T>();
-    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
-      cudaMemcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
-                 cols * sizeof(T), cudaMemcpyDeviceToDevice);
-    }
-  }
-};
-
-template <typename T>
-class MultiplexGradGPUKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto ins = ctx.MultiInput<Tensor>("X");
-    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
-    for (size_t i = 1; i < d_ins.size(); ++i) {
-      if (d_ins[i]) {
-        d_ins[i]->mutable_data<T>(ctx.GetPlace());
-        auto dims = d_ins[i]->dims();
-        cudaMemset(d_ins[i]->data<T>(), 0,
-                   framework::product(dims) * sizeof(T));
-      }
-    }
-
-    auto rows = ins[1]->dims()[0];
-    auto cols = ins[1]->dims()[1];
-    // copy index to cpu
-    Tensor index_t_cpu;
-    index_t_cpu.CopyFrom<T>(*(ins[0]), paddle::platform::CPUPlace());
-    auto index = index_t_cpu.data<T>();
-    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
-      if (d_ins[k]) {
-        cudaMemcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
-                   cols * sizeof(T), cudaMemcpyDeviceToDevice);
-      }
-    }
-  }
-};
-}  // namespace operators
-}  // namespace paddle
+#include "paddle/operators/multiplex_op.h"
 
 namespace ops = paddle::operators;
 
-REGISTER_OP_GPU_KERNEL(multiplex, ops::MultiplexGPUKernel<float>);
-REGISTER_OP_GPU_KERNEL(multiplex_grad, ops::MultiplexGradGPUKernel<float>);
+REGISTER_OP_GPU_KERNEL(multiplex,
+                       ops::MultiplexKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    multiplex_grad,
+    ops::MultiplexGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h
index 82b4a2c4c7..dcc01d0f98 100644
--- a/paddle/operators/multiplex_op.h
+++ b/paddle/operators/multiplex_op.h
@@ -17,31 +17,56 @@
 
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/memory/memcpy.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T>
-class MultiplexCPUKernel : public framework::OpKernel {
+template <typename Place, typename T>
+class MultiplexKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto* out = ctx.Output<framework::LoDTensor>("Out");
+
     out->mutable_data<T>(ctx.GetPlace());
 
-    auto index = ins[0]->data<T>();
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
-      memcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
-             cols * sizeof(T));
+    if (platform::is_cpu_place(ctx.GetPlace())) {
+      auto* index = ins[0]->data<T>();
+      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
+      for (auto i = 0; i < rows; i++) {
+        int k = (int)index[i] + 1;
+        PADDLE_ENFORCE_LT(k, ins.size(),
+                          "index exceeds the number of candidate tensors.");
+        memory::Copy(place, out->data<T>() + i * cols, place,
+                     ins[k]->data<T>() + i * cols, cols * sizeof(T));
+      }
+    } else {
+#ifndef PADDLE_ONLY_CPU
+      // copy index to cpu
+      framework::Tensor index_t_cpu;
+      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+      auto* index = index_t_cpu.data<T>();
+      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                        ctx.device_context())
+                        .stream();
+      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
+      for (auto i = 0; i < rows; i++) {
+        int k = (int)index[i] + 1;
+        PADDLE_ENFORCE_LT(k, ins.size(),
+                          "index exceeds the number of candidate tensors.");
+        memory::Copy(place, out->data<T>() + i * cols, place,
+                     ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
+      }
+#endif
     }
   }
 };
 
-template <typename T>
-class MultiplexGradCPUKernel : public framework::OpKernel {
+template <typename Place, typename T>
+class MultiplexGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
@@ -51,20 +76,42 @@ class MultiplexGradKernel : public framework::OpKernel {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto d_ins =
         ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     for (size_t i = 1; i < d_ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->mutable_data<T>(ctx.GetPlace());
-        auto dims = d_ins[i]->dims();
-        memset(d_ins[i]->data<T>(), 0, framework::product(dims) * sizeof(T));
+        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
+        t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
       }
     }
 
-    auto index = ins[0]->data<T>();
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
-      if (d_ins[k]) {
-        memcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
-               cols * sizeof(T));
+    if (platform::is_cpu_place(ctx.GetPlace())) {
+      auto* index = ins[0]->data<T>();
+      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
+      for (auto i = 0; i < rows; i++) {
+        int k = (int)index[i] + 1;
+        if (d_ins[k]) {
+          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                       d_out->data<T>() + i * cols, cols * sizeof(T));
+        }
+      }
+    } else {
+#ifndef PADDLE_ONLY_CPU
+      // copy index to cpu
+      framework::Tensor index_t_cpu;
+      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+      auto* index = index_t_cpu.data<T>();
+
+      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                        ctx.device_context())
+                        .stream();
+      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
+      for (auto i = 0; i < rows; i++) {
+        int k = (int)index[i] + 1;
+        if (d_ins[k]) {
+          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                       d_out->data<T>() + i * cols, cols * sizeof(T), stream);
+        }
+      }
+#endif
     }
   }
 };

From fb52bc6e122d249b3e2d8168de81f9e52b980322 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 25 Sep 2017 11:18:38 +0800
Subject: [PATCH 4/4] revert code layout in multiplex_op

---
 paddle/operators/multiplex_op.cc |  6 +--
 paddle/operators/multiplex_op.cu | 77 ++++++++++++++++++++++++++++++--
 paddle/operators/multiplex_op.h  | 75 +++++++------------------------
 3 files changed, 94 insertions(+), 64 deletions(-)

diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
index 6b22c782fe..6e77b86b56 100644
--- a/paddle/operators/multiplex_op.cc
+++ b/paddle/operators/multiplex_op.cc
@@ -106,8 +106,8 @@ namespace ops = paddle::operators;
 
 REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
             ops::MultiplexGradOp);
-REGISTER_OP_CPU_KERNEL(multiplex,
-                       ops::MultiplexKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    multiplex, ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     multiplex_grad,
-    ops::MultiplexGradKernel<paddle::platform::CPUPlace, float>);
+    ops::MultiplexGradCPUKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu
index 3d219389ba..4736f15bd5 100644
--- a/paddle/operators/multiplex_op.cu
+++ b/paddle/operators/multiplex_op.cu
@@ -15,10 +15,81 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/multiplex_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class MultiplexGPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    // copy index to cpu
+    framework::Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+    auto* index = index_t_cpu.data<T>();
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                      ctx.device_context())
+                      .stream();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      PADDLE_ENFORCE_LT(k, ins.size(),
+                        "index exceeds the number of candidate tensors.");
+      memory::Copy(place, out->data<T>() + i * cols, place,
+                   ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class MultiplexGradGPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto d_ins =
+        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
+    for (size_t i = 1; i < d_ins.size(); i++) {
+      if (d_ins[i]) {
+        d_ins[i]->mutable_data<T>(ctx.GetPlace());
+        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
+        t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+      }
+    }
+
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    // copy index to cpu
+    framework::Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+    auto* index = index_t_cpu.data<T>();
+
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                      ctx.device_context())
+                      .stream();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      if (d_ins[k]) {
+        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                     d_out->data<T>() + i * cols, cols * sizeof(T), stream);
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 
-REGISTER_OP_GPU_KERNEL(multiplex,
-                       ops::MultiplexKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    multiplex, ops::MultiplexGPUKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     multiplex_grad,
-    ops::MultiplexGradKernel<paddle::platform::GPUPlace, float>);
+    ops::MultiplexGradGPUKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h
index dcc01d0f98..44e8e0c199 100644
--- a/paddle/operators/multiplex_op.h
+++ b/paddle/operators/multiplex_op.h
@@ -23,7 +23,7 @@ namespace paddle {
 namespace operators {
 
 template <typename Place, typename T>
-class MultiplexKernel : public framework::OpKernel {
+class MultiplexCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
@@ -33,40 +33,20 @@ class MultiplexCPUKernel : public framework::OpKernel {
 
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    if (platform::is_cpu_place(ctx.GetPlace())) {
-      auto* index = ins[0]->data<T>();
-      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        PADDLE_ENFORCE_LT(k, ins.size(),
-                          "index exceeds the number of candidate tensors.");
-        memory::Copy(place, out->data<T>() + i * cols, place,
-                     ins[k]->data<T>() + i * cols, cols * sizeof(T));
-      }
-    } else {
-#ifndef PADDLE_ONLY_CPU
-      // copy index to cpu
-      framework::Tensor index_t_cpu;
-      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-      auto* index = index_t_cpu.data<T>();
-      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                        ctx.device_context())
-                        .stream();
-      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        PADDLE_ENFORCE_LT(k, ins.size(),
-                          "index exceeds the number of candidate tensors.");
-        memory::Copy(place, out->data<T>() + i * cols, place,
-                     ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
-      }
-#endif
+    auto* index = ins[0]->data<T>();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      PADDLE_ENFORCE_LT(k, ins.size(),
+                        "index exceeds the number of candidate tensors.");
+      memory::Copy(place, out->data<T>() + i * cols, place,
+                   ins[k]->data<T>() + i * cols, cols * sizeof(T));
     }
   }
 };
 
 template <typename Place, typename T>
-class MultiplexGradKernel : public framework::OpKernel {
+class MultiplexGradCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
@@ -83,35 +63,14 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
     for (size_t i = 1; i < d_ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->mutable_data<T>(ctx.GetPlace());
         auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
         t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
       }
     }
 
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    if (platform::is_cpu_place(ctx.GetPlace())) {
-      auto* index = ins[0]->data<T>();
-      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        if (d_ins[k]) {
-          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
-                       d_out->data<T>() + i * cols, cols * sizeof(T));
-        }
-      }
-    } else {
-#ifndef PADDLE_ONLY_CPU
-      // copy index to cpu
-      framework::Tensor index_t_cpu;
-      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-      auto* index = index_t_cpu.data<T>();
-
-      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                        ctx.device_context())
-                        .stream();
-      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        if (d_ins[k]) {
-          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
-                       d_out->data<T>() + i * cols, cols * sizeof(T), stream);
-        }
+    auto* index = ins[0]->data<T>();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      if (d_ins[k]) {
+        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                     d_out->data<T>() + i * cols, cols * sizeof(T));
       }
-#endif
     }
   }
 };
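
A note on the semantics, for reviewers: the row-selection rule described in the operator's DOC comment, and checked by test_multiplex_op.py, can be sketched in a few lines of NumPy. This is an illustrative reference only -- the function names below are hypothetical and nothing here ships with the patches. Given inputs [index, x1, ..., xN], the kernels read row i from input k = index[i] + 1, i.e. from candidate index[i] once the index tensor itself is skipped:

import numpy as np

def multiplex_forward(index, candidates):
    # candidates: list of equally-shaped (rows, cols) arrays; row i of
    # the output is row i of candidates[int(index[i])].
    out = np.zeros_like(candidates[0])
    for i in range(len(index)):
        out[i] = candidates[int(index[i])][i]
    return out

def multiplex_backward(index, candidates, d_out):
    # Gradient routing: row i of d_out flows back only to the candidate
    # that produced row i of the output; every other row stays zero,
    # which is why the kernels zero-fill d_ins before scattering.
    d_ins = [np.zeros_like(c) for c in candidates]
    for i in range(len(index)):
        d_ins[int(index[i])][i] = d_out[i]
    return d_ins

With index = [3, 1, 0] as in the unit test, rows 0..2 of the output come from x4, x2 and x1 respectively (k = index[i] + 1 over [index, x1, x2, x3, x4]), matching the reference output the test builds by hand; the index tensor itself receives no gradient, which is why patch 2 starts the gradient loops at i = 1.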