From 9bd9d8b5ca96ff442a7ba3a3df0564e414c11af5 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Thu, 18 Jan 2018 11:29:09 +0800
Subject: [PATCH 1/5] Add sequence_reshape_op.

---
 paddle/operators/sequence_reshape_op.cc |  78 +++++++++++++++
 paddle/operators/sequence_reshape_op.h  | 127 ++++++++++++++++++++++++
 2 files changed, 205 insertions(+)
 create mode 100644 paddle/operators/sequence_reshape_op.cc
 create mode 100644 paddle/operators/sequence_reshape_op.h

diff --git a/paddle/operators/sequence_reshape_op.cc b/paddle/operators/sequence_reshape_op.cc
new file mode 100644
index 0000000000..31a970354f
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.cc
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_reshape_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceReshapeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceReshapeOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceReshapeOp should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
+    int dimension = ctx->Attrs().Get<int>("dimension");
+    ctx->SetOutputDim("Out", {{x_dims[0], static_cast<int64_t>(dimension)}});
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "");
+    AddOutput("Out", "");
+    AddAttr<int>("dimension", "");
+    AddAttr<bool>("is_padding", "Default padding zero.");
+    AddComment(R"DOC()DOC");
+  }
+};
+
+class SequenceReshapeGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput(framework::GradVarName("Out")),
+        "Input(Out@GRAD) of SequenceReshapeGradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"),
+                   "Input(Out) of SequenceReshapeGradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceReshapeGradOp should not be null.");
+
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
+                  ops::SequenceReshapeOpMaker);
+REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
+REGISTER_OP_CPU_KERNEL(
+    sequence_reshape,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(
+    sequence_reshape_grad,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h
new file mode 100644
index 0000000000..bc7694b6b1
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.h
@@ -0,0 +1,127 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+template <typename DeviceContext, typename T>
+class SequenceReshapeKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<LoDTensor>("X");
+    auto* out = context.Output<LoDTensor>("Out");
+    int out_width = context.Attr<int>("dimension");
+    bool whether_padding = context.Attr<bool>("whether_padding");
+
+    const T* p_in_data = in->data<T>();
+    T* p_out_data = out->mutable_data<T>(context.GetPlace());
+
+    // compute shape for output
+    auto in_dims = in->dims();
+    int64_t in_width = in_dims[1];
+    auto& in_lod = in->lod();
+
+    PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
+                      "Only support one level sequence now.");
+    PADDLE_ENFORCE_GE(
+        in_dims[0],
+        /* batch size = */ static_cast<int64_t>(in_lod[0].size() - 1),
+        "The 1st dimension of Input(X) must be equal or larger than batch "
+        "size.");
+
+    auto in_lod_l0 = in_lod[0];
+    int seq_num = in_lod_l0.size() - 1;
+
+    auto& out_lod = *out->mutable_lod();
+    out_lod.push_back(std::vector<size_t>({0}));
+    size_t offset = 0;
+    for (int i = 0; i < seq_num; ++i) {
+      size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
+      if (whether_padding) {
+        offset += std::ceil((float)(seq_len * in_width) / out_width);
+      } else {
+        offset += (seq_len * in_width) / out_width;
+      }
+      out_lod[0].push_back(offset);
+    }
+
+    out->Resize({{static_cast<int64_t>(out_lod[0].back()), out_width}});
+    math::set_constant(context.device_context(), out, 0.0f);
+
+    for (int i = 0; i < seq_num; ++i) {
+      size_t in_offset = in_lod_l0[i] * in_width;
+      size_t out_offset = out_lod[0][i] * out_width;
+      size_t bytes = sizeof(T) * (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width;
+      if (platform::is_cpu_place(context.GetPlace())) {
+        std::memcpy(p_out_data + out_offset, p_in_data + in_offset, bytes);
+      } else {
+#ifdef PADDLE_WITH_CUDA
+        auto& dev_ctx = context.template device_context<DeviceContext>();
+        memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
+                     p_out_data + out_offset,
+                     boost::get<platform::CUDAPlace>(context.GetPlace()),
+                     p_in_data + in_offset, bytes, dev_ctx.stream());
+#endif
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SequenceReshapeGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x_tensor_ptr = context.Input<LoDTensor>("X");
+    auto* out_tensor_ptr = context.Input<LoDTensor>("Out");
+    auto* out_grad_tensor_ptr =
+        context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* x_grad_tensor_ptr =
+        context.Output<LoDTensor>(framework::GradVarName("X"));
+
+    T* p_x_grad_data = x_grad_tensor_ptr->mutable_data<T>(context.GetPlace());
+    const T* p_out_grad_data = out_grad_tensor_ptr->data<T>();
+
+    auto& x_lod = x_tensor_ptr->lod();
+    int seq_num = x_lod[0].size() - 1;
+    int x_width = x_tensor_ptr->dims()[1];
+    auto& out_lod = out_tensor_ptr->lod();
+    int out_width = out_tensor_ptr->dims()[1];
+
+    for (int i = 0; i < seq_num; ++i) {
+      size_t src_offset = out_lod[0][i] * out_width;
+      size_t dst_offset = x_lod[0][i] * x_width;
+      size_t bytes = sizeof(T) * (x_lod[0][i + 1] - x_lod[0][i]) * x_width;
+      if (platform::is_cpu_place(context.GetPlace())) {
+        std::memcpy(p_x_grad_data + dst_offset, p_out_grad_data + src_offset,
+                    bytes);
+      } else {
+#ifdef PADDLE_WITH_CUDA
+        auto& dev_ctx = context.template device_context<DeviceContext>();
+        memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
+                     p_x_grad_data + dst_offset,
+                     boost::get<platform::CUDAPlace>(context.GetPlace()),
+                     p_out_grad_data + src_offset, bytes, dev_ctx.stream());
+#endif
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
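Note on PATCH 1/5: the LoD arithmetic above is the core of this first version. Every sequence of seq_len rows with width in_width is repacked into rows of width out_width, rounding up when whether_padding is set. (The Maker registers the attribute as "is_padding" while the kernel reads "whether_padding"; the next patch drops the padding option entirely, which also removes this inconsistency.) A minimal Python sketch of the offset rule, illustrative only and with a hypothetical helper name:

    import math

    def reshape_lod(in_lod0, in_width, out_width, whether_padding=False):
        # in_lod0 holds level-0 offsets; sequence i occupies rows
        # [in_lod0[i], in_lod0[i + 1]) of an [N, in_width] LoDTensor.
        out_lod0 = [0]
        offset = 0
        for i in range(len(in_lod0) - 1):
            seq_len = in_lod0[i + 1] - in_lod0[i]
            if whether_padding:
                # round up; the tail of the last output row is zero-padded
                offset += int(math.ceil(float(seq_len * in_width) / out_width))
            else:
                # truncate; leftover elements are dropped
                offset += (seq_len * in_width) // out_width
            out_lod0.append(offset)
        return out_lod0

    # reshape_lod([0, 2, 6], in_width=2, out_width=4) returns [0, 1, 3]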
From bea41444d78faba95a0e7f7fda79edf62afbe7cf Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Thu, 18 Jan 2018 20:16:03 +0800
Subject: [PATCH 2/5] Refine the implementation and add unit test.

---
 paddle/operators/sequence_reshape_op.cc       | 64 +++++++++++++---
 paddle/operators/sequence_reshape_op.cu       | 23 ++++++
 paddle/operators/sequence_reshape_op.h        | 67 ++++++++++------
 .../v2/fluid/tests/test_sequence_reshape.py   | 76 +++++++++++++++++++
 4 files changed, 196 insertions(+), 34 deletions(-)
 create mode 100644 paddle/operators/sequence_reshape_op.cu
 create mode 100644 python/paddle/v2/fluid/tests/test_sequence_reshape.py

diff --git a/paddle/operators/sequence_reshape_op.cc b/paddle/operators/sequence_reshape_op.cc
index 31a970354f..308de59c64 100644
--- a/paddle/operators/sequence_reshape_op.cc
+++ b/paddle/operators/sequence_reshape_op.cc
@@ -27,9 +27,8 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
                    "Output(Out) of SequenceReshapeOp should not be null.");
     auto x_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
-    int dimension = ctx->Attrs().Get<int>("dimension");
-    ctx->SetOutputDim("Out", {{x_dims[0], static_cast<int64_t>(dimension)}});
-    ctx->ShareLoD("X", /*->*/ "Out");
+    int dimension = ctx->Attrs().Get<int>("new_dim");
+    ctx->SetOutputDim("Out", {x_dims[0], static_cast<int64_t>(dimension)});
   }
 };
 
@@ -37,11 +36,41 @@ class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "");
-    AddOutput("Out", "");
-    AddAttr<int>("dimension", "");
-    AddAttr<bool>("is_padding", "Default padding zero.");
-    AddComment(R"DOC()DOC");
+    AddInput("X",
+             "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
+             "shape being [N, M].");
+    AddOutput("Out",
+              "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
+              "shape [T, new_dim] where T is calculated based on X.lod, M and "
+              "new_dim.");
+    AddAttr<int>("new_dim", "Sequence dimension of the output LoDTensor.");
+    AddComment(R"DOC(
+Sequence Reshape Operator.
+
+This operator will rearrange the input sequences. The new dimension is set by
+attribute and length of each sequence may change longer or shorter which is
+decided by original length, original dimension and new dimension. The following
+example will help to illustrate the function of this operator:
+
+x is a LoDTensor:
+  x.lod = [[0, 2, 6]]
+  x.data = [[0.1, 0.2], [0.3, 0.4],
+            [0.5, 0.6], [0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]
+  x.dims = [6, 2]
+
+set new_dim = 4
+
+then out is a LoDTensor:
+  out.lod = [[0, 1, 3]]
+  out.data = [[0.1, 0.2, 0.3, 0.4],
+              [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
+  out.dims = [3, 4]
+
+Currently, only 1-level LoDTensor is supported and please make sure (original
+length * original dimension) can be divided by new_dim with no remainder for
+each sequence.
+
+)DOC");
   }
 };
 
@@ -63,12 +92,29 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
+class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op_desc_ptr = new framework::OpDesc();
+    op_desc_ptr->SetType("sequence_reshape_grad");
+    op_desc_ptr->SetInput("X", Input("X"));
+    op_desc_ptr->SetInput("Out", Output("Out"));
+    op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op_desc_ptr->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
-                  ops::SequenceReshapeOpMaker);
+                  ops::SequenceReshapeOpMaker, ops::SequenceReshapeGradOpMaker);
 REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
 REGISTER_OP_CPU_KERNEL(
     sequence_reshape,
diff --git a/paddle/operators/sequence_reshape_op.cu b/paddle/operators/sequence_reshape_op.cu
new file mode 100644
index 0000000000..dc620ef522
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.cu
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_reshape_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reshape,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reshape_grad,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>);
diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h
index bc7694b6b1..8e302a364b 100644
--- a/paddle/operators/sequence_reshape_op.h
+++ b/paddle/operators/sequence_reshape_op.h
@@ -26,53 +26,63 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<LoDTensor>("Out");
-    int out_width = context.Attr<int>("dimension");
-    bool whether_padding = context.Attr<bool>("whether_padding");
+    int out_width = context.Attr<int>("new_dim");
 
     const T* p_in_data = in->data<T>();
-    T* p_out_data = out->mutable_data<T>(context.GetPlace());
 
-    // compute shape for output
     auto in_dims = in->dims();
     int64_t in_width = in_dims[1];
     auto& in_lod = in->lod();
 
     PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
                       "Only support one level sequence now.");
-    PADDLE_ENFORCE_GE(
-        in_dims[0],
-        /* batch size = */ static_cast<int64_t>(in_lod[0].size() - 1),
-        "The 1st dimension of Input(X) must be equal or larger than batch "
-        "size.");
+    PADDLE_ENFORCE_EQ(
+        in_dims[0], in_lod[0].back(),
+        "Inconsistent size between X.shape[0] and X.lod()[0].back().");
 
     auto in_lod_l0 = in_lod[0];
     int seq_num = in_lod_l0.size() - 1;
 
     auto& out_lod = *out->mutable_lod();
-    out_lod.push_back(std::vector<size_t>({0}));
-    size_t offset = 0;
+    out_lod.resize(1);
+    out_lod[0].clear();
+    out_lod[0].push_back(0);
     for (int i = 0; i < seq_num; ++i) {
       size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
-      if (whether_padding) {
-        offset += std::ceil((float)(seq_len * in_width) / out_width);
-      } else {
-        offset += (seq_len * in_width) / out_width;
-      }
-      out_lod[0].push_back(offset);
+      size_t offset = 0;
+      offset = (seq_len * in_width) / out_width;
+      PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
+                        "Please make sure (sequence_length * dimension) can be "
+                        "divided by new_dim with no remainder for each "
+                        "sequence. The %dth sequence is invalid.",
+                        i + 1);
+      PADDLE_ENFORCE_GT(offset, 0,
+                        "Illegal operation, length of the %dth sequence become "
+                        "to 0 after reshaped.",
+                        i + 1);
+      out_lod[0].push_back(out_lod[0].back() + offset);
     }
 
-    out->Resize({{static_cast<int64_t>(out_lod[0].back()), out_width}});
+    out->mutable_data<T>(context.GetPlace());
+    out->Resize({static_cast<int64_t>(out_lod[0].back()), out_width});
+    T* p_out_data = out->mutable_data<T>(context.GetPlace());
     math::set_constant(context.device_context(), out, 0.0f);
 
     for (int i = 0; i < seq_num; ++i) {
       size_t in_offset = in_lod_l0[i] * in_width;
       size_t out_offset = out_lod[0][i] * out_width;
-      size_t bytes = sizeof(T) * (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width;
+      size_t in_count = (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width;
+      size_t out_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
+      size_t bytes = sizeof(T) * std::min(in_count, out_count);
       if (platform::is_cpu_place(context.GetPlace())) {
-        std::memcpy(p_out_data + out_offset, p_in_data + in_offset, bytes);
+        memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
+                     p_out_data + out_offset,
+                     boost::get<platform::CPUPlace>(context.GetPlace()),
+                     p_in_data + in_offset, bytes);
       } else {
 #ifdef PADDLE_WITH_CUDA
-        auto& dev_ctx = context.template device_context<DeviceContext>();
+        auto& dev_ctx =
+            context.template device_context<platform::CUDADeviceContext>();
         memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                      p_out_data + out_offset,
                      boost::get<platform::CUDAPlace>(context.GetPlace()),
@@ -103,16 +113,23 @@ class SequenceReshapeGradKernel : public framework::OpKernel<T> {
     auto& out_lod = out_tensor_ptr->lod();
     int out_width = out_tensor_ptr->dims()[1];
 
+    math::set_constant(context.device_context(), x_grad_tensor_ptr, 0.0f);
+
     for (int i = 0; i < seq_num; ++i) {
       size_t src_offset = out_lod[0][i] * out_width;
       size_t dst_offset = x_lod[0][i] * x_width;
-      size_t bytes = sizeof(T) * (x_lod[0][i + 1] - x_lod[0][i]) * x_width;
+      size_t src_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
+      size_t dst_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width;
+      size_t bytes = sizeof(T) * std::min(src_count, dst_count);
       if (platform::is_cpu_place(context.GetPlace())) {
-        std::memcpy(p_x_grad_data + dst_offset, p_out_grad_data + src_offset,
-                    bytes);
+        memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
+                     p_x_grad_data + dst_offset,
+                     boost::get<platform::CPUPlace>(context.GetPlace()),
+                     p_out_grad_data + src_offset, bytes);
       } else {
 #ifdef PADDLE_WITH_CUDA
-        auto& dev_ctx = context.template device_context<DeviceContext>();
+        auto& dev_ctx =
+            context.template device_context<platform::CUDADeviceContext>();
         memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                      p_x_grad_data + dst_offset,
                      boost::get<platform::CUDAPlace>(context.GetPlace()),
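Note on PATCH 2/5: the unit test added next exercises exactly this contract. With the test's input of shape [11, 24], lod [[0, 4, 5, 8, 11]] and new_dim = 12, every sequence length doubles. A NumPy sketch of the expected result (an illustration of the semantics under the op's exact-divisibility requirement, not the test code itself):

    import numpy as np

    x = np.random.uniform(0.1, 1, [11, 24]).astype('float32')
    x_lod0, new_dim = [0, 4, 5, 8, 11], 12
    # every LoD offset scales by in_width / new_dim = 24 / 12 = 2
    out_lod0 = [offset * 24 // new_dim for offset in x_lod0]
    # with exact divisibility, the kernel's per-sequence copies amount to a
    # plain row-major reshape of the whole buffer
    out = x.reshape(-1, new_dim)
    assert out.shape == (22, 12) and out_lod0 == [0, 8, 10, 16, 22]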
diff --git a/python/paddle/v2/fluid/tests/test_sequence_reshape.py b/python/paddle/v2/fluid/tests/test_sequence_reshape.py
new file mode 100644
index 0000000000..91ff275821
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_sequence_reshape.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import unittest
+import numpy as np
+import math
+from op_test import OpTest
+
+
+class TestSequenceReshape(OpTest):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 12
+        x_lod = [[0, 4, 5, 8, 11]]
+        x = np.random.uniform(0.1, 1, [11, 24]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+    def compute_output(self, x, x_lod, dimension):
+        x_width = x.shape[1]
+        out_lod = [[0]]
+        for i in xrange(len(x_lod[0]) - 1):
+            seq_len = x_lod[0][i + 1] - x_lod[0][i]
+            offset = (seq_len * x_width) / dimension
+            assert int(offset) * dimension == seq_len * x_width
+            out_lod[0].append(out_lod[0][-1] + int(offset))
+        out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
+        for i in xrange(len(x_lod[0]) - 1):
+            x_offset = x_lod[0][i] * x_width
+            out_offset = out_lod[0][i] * dimension
+            out_count = (out_lod[0][i + 1] - out_lod[0][i]) * dimension
+            x_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width
+            count = min(out_count, x_count)
+            out.ravel()[out_offset:out_offset + count] = x.ravel()[
+                x_offset:x_offset + count]
+        return out, out_lod
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestSequenceReshape_reduce(TestSequenceReshape):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 24
+        x_lod = [[0, 4, 6, 8, 12]]
+        x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+
+if __name__ == '__main__':
+    unittest.main()
From fc581bc5f20e3dd3a2a518a1eb5120abf89a3a52 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Thu, 18 Jan 2018 20:27:43 +0800
Subject: [PATCH 3/5] Change the CopyRight.

---
 paddle/operators/sequence_reshape_op.cc | 2 +-
 paddle/operators/sequence_reshape_op.cu | 2 +-
 paddle/operators/sequence_reshape_op.h  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/operators/sequence_reshape_op.cc b/paddle/operators/sequence_reshape_op.cc
index 308de59c64..ddedbc3bc6 100644
--- a/paddle/operators/sequence_reshape_op.cc
+++ b/paddle/operators/sequence_reshape_op.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/paddle/operators/sequence_reshape_op.cu b/paddle/operators/sequence_reshape_op.cu
index dc620ef522..9ba0e34e27 100644
--- a/paddle/operators/sequence_reshape_op.cu
+++ b/paddle/operators/sequence_reshape_op.cu
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h
index 8e302a364b..7c2215f772 100644
--- a/paddle/operators/sequence_reshape_op.h
+++ b/paddle/operators/sequence_reshape_op.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
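Note ahead of PATCH 4/5: so far the backward kernel zero-fills X@GRAD and copies min(src_count, dst_count) elements per sequence. When the divisibility check holds, the two counts are equal and all offsets line up, so the whole backward pass reduces to reshaping Out@GRAD back into X's shape, which is exactly the simplification the next patch applies. A NumPy sketch of that identity (a reference model, not framework code):

    import numpy as np

    def sequence_reshape_grad(d_out, x_shape):
        # an order-preserving flat copy back into X's shape; equivalent to the
        # per-sequence memory::Copy loop whenever every sequence divides exactly
        return d_out.reshape(x_shape)

    d_out = np.arange(12.0).reshape(3, 4)       # gradient of a [3, 4] Out
    d_x = sequence_reshape_grad(d_out, (6, 2))  # X had dims [6, 2]
    assert (d_x.ravel() == d_out.ravel()).all()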
From 08cb472ab90495d536a91b63135930c2397974b6 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Fri, 19 Jan 2018 12:42:54 +0800
Subject: [PATCH 4/5] Simplify the implementation.

---
 paddle/operators/sequence_reshape_op.cc       |  30 +++--
 paddle/operators/sequence_reshape_op.cu       |  11 +-
 paddle/operators/sequence_reshape_op.h        | 107 ++++--------------
 .../v2/fluid/tests/test_sequence_reshape.py   |  24 ++--
 4 files changed, 68 insertions(+), 104 deletions(-)

diff --git a/paddle/operators/sequence_reshape_op.cc b/paddle/operators/sequence_reshape_op.cc
index ddedbc3bc6..884c49276c 100644
--- a/paddle/operators/sequence_reshape_op.cc
+++ b/paddle/operators/sequence_reshape_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/operators/sequence_reshape_op.h"
+#include "paddle/framework/ddim.h"
 
 namespace paddle {
 namespace operators {
@@ -26,9 +27,11 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SequenceReshapeOp should not be null.");
     auto x_dims = ctx->GetInputDim("X");
+    auto x_numel = product(x_dims);
     PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
-    int dimension = ctx->Attrs().Get<int>("new_dim");
-    ctx->SetOutputDim("Out", {x_dims[0], static_cast<int64_t>(dimension)});
+    int new_dim = ctx->Attrs().Get<int>("new_dim");
+    ctx->SetOutputDim("Out",
+                      {x_numel / new_dim, static_cast<int64_t>(new_dim)});
   }
 };
 
@@ -54,16 +57,16 @@ example will help to illustrate the function of this operator:
 
 x is a LoDTensor:
   x.lod = [[0, 2, 6]]
-  x.data = [[0.1, 0.2], [0.3, 0.4],
-            [0.5, 0.6], [0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]
+  x.data = [[1, 2], [3, 4],
+            [5, 6], [7, 8], [9, 10], [11, 12]]
   x.dims = [6, 2]
 
 set new_dim = 4
 
 then out is a LoDTensor:
-  out.lod = [[0, 1, 3]]
-  out.data = [[0.1, 0.2, 0.3, 0.4],
-              [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
+  out.lod = [[0, 1, 3]]
+  out.data = [[1, 2, 3, 4],
+              [5, 6, 7, 8], [9, 10, 11, 12]]
   out.dims = [3, 4]
 
 Currently, only 1-level LoDTensor is supported and please make sure (original
@@ -82,8 +85,6 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(
         ctx->HasInput(framework::GradVarName("Out")),
         "Input(Out@GRAD) of SequenceReshapeGradOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Out"),
-                   "Input(Out) of SequenceReshapeGradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of SequenceReshapeGradOp should not be null.");
 
@@ -101,7 +102,6 @@ class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
     auto* op_desc_ptr = new framework::OpDesc();
     op_desc_ptr->SetType("sequence_reshape_grad");
     op_desc_ptr->SetInput("X", Input("X"));
-    op_desc_ptr->SetInput("Out", Output("Out"));
     op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
     op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
     op_desc_ptr->SetAttrMap(Attrs());
@@ -118,7 +118,13 @@ REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
 REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
 REGISTER_OP_CPU_KERNEL(
     sequence_reshape,
-    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     sequence_reshape_grad,
-    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/operators/sequence_reshape_op.cu b/paddle/operators/sequence_reshape_op.cu
index 9ba0e34e27..d9c2f7e9a4 100644
--- a/paddle/operators/sequence_reshape_op.cu
+++ b/paddle/operators/sequence_reshape_op.cu
@@ -17,7 +17,14 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     sequence_reshape,
-    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     sequence_reshape_grad,
-    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h
index 7c2215f772..623904ec7c 100644
--- a/paddle/operators/sequence_reshape_op.h
+++ b/paddle/operators/sequence_reshape_op.h
@@ -28,8 +28,6 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
     auto* out = context.Output<LoDTensor>("Out");
     int out_width = context.Attr<int>("new_dim");
 
-    const T* p_in_data = in->data<T>();
-
     auto in_dims = in->dims();
     int64_t in_width = in_dims[1];
     auto& in_lod = in->lod();
@@ -43,53 +41,29 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
     auto in_lod_l0 = in_lod[0];
     int seq_num = in_lod_l0.size() - 1;
 
-    auto& out_lod = *out->mutable_lod();
-    out_lod.resize(1);
-    out_lod[0].clear();
-    out_lod[0].push_back(0);
-    for (int i = 0; i < seq_num; ++i) {
-      size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
-      size_t offset = 0;
-      offset = (seq_len * in_width) / out_width;
-      PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
-                        "Please make sure (sequence_length * dimension) can be "
-                        "divided by new_dim with no remainder for each "
-                        "sequence. The %dth sequence is invalid.",
-                        i + 1);
-      PADDLE_ENFORCE_GT(offset, 0,
-                        "Illegal operation, length of the %dth sequence become "
-                        "to 0 after reshaped.",
-                        i + 1);
-      out_lod[0].push_back(out_lod[0].back() + offset);
+    if (in_width == out_width) {
+      out->set_lod(in->lod());
+    } else {
+      auto& out_lod = *out->mutable_lod();
+      out_lod.resize(1);
+      out_lod[0].clear();
+      out_lod[0].push_back(0);
+      for (int i = 0; i < seq_num; ++i) {
+        size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
+        size_t offset = 0;
+        offset = (seq_len * in_width) / out_width;
+        PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
+                          "Please make sure (sequence_length * dimension) can "
+                          "be divided by new_dim with no remainder for each "
+                          "sequence. The %dth sequence is invalid.",
+                          i + 1);
+        out_lod[0].push_back(out_lod[0].back() + offset);
+      }
     }
 
     out->mutable_data<T>(context.GetPlace());
-    out->Resize({static_cast<int64_t>(out_lod[0].back()), out_width});
-    T* p_out_data = out->mutable_data<T>(context.GetPlace());
-    math::set_constant(context.device_context(), out, 0.0f);
-
-    for (int i = 0; i < seq_num; ++i) {
-      size_t in_offset = in_lod_l0[i] * in_width;
-      size_t out_offset = out_lod[0][i] * out_width;
-      size_t in_count = (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width;
-      size_t out_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
-      size_t bytes = sizeof(T) * std::min(in_count, out_count);
-      if (platform::is_cpu_place(context.GetPlace())) {
-        memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
-                     p_out_data + out_offset,
-                     boost::get<platform::CPUPlace>(context.GetPlace()),
-                     p_in_data + in_offset, bytes);
-      } else {
-#ifdef PADDLE_WITH_CUDA
-        auto& dev_ctx =
-            context.template device_context<platform::CUDADeviceContext>();
-        memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
-                     p_out_data + out_offset,
-                     boost::get<platform::CUDAPlace>(context.GetPlace()),
-                     p_in_data + in_offset, bytes, dev_ctx.stream());
-#endif
-      }
-    }
+    framework::Copy(*in, context.GetPlace(), out);
+    out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
   }
 };
 
@@ -98,45 +72,14 @@ class SequenceReshapeGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x_tensor_ptr = context.Input<LoDTensor>("X");
-    auto* out_tensor_ptr = context.Input<LoDTensor>("Out");
-    auto* out_grad_tensor_ptr =
+    auto* outg_tensor_ptr =
         context.Input<LoDTensor>(framework::GradVarName("Out"));
-    auto* x_grad_tensor_ptr =
+    auto* xg_tensor_ptr =
         context.Output<LoDTensor>(framework::GradVarName("X"));
 
-    T* p_x_grad_data = x_grad_tensor_ptr->mutable_data<T>(context.GetPlace());
-    const T* p_out_grad_data = out_grad_tensor_ptr->data<T>();
-
-    auto& x_lod = x_tensor_ptr->lod();
-    int seq_num = x_lod[0].size() - 1;
-    int x_width = x_tensor_ptr->dims()[1];
-    auto& out_lod = out_tensor_ptr->lod();
-    int out_width = out_tensor_ptr->dims()[1];
-
-    math::set_constant(context.device_context(), x_grad_tensor_ptr, 0.0f);
-
-    for (int i = 0; i < seq_num; ++i) {
-      size_t src_offset = out_lod[0][i] * out_width;
-      size_t dst_offset = x_lod[0][i] * x_width;
-      size_t src_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
-      size_t dst_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width;
-      size_t bytes = sizeof(T) * std::min(src_count, dst_count);
-      if (platform::is_cpu_place(context.GetPlace())) {
-        memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
-                     p_x_grad_data + dst_offset,
-                     boost::get<platform::CPUPlace>(context.GetPlace()),
-                     p_out_grad_data + src_offset, bytes);
-      } else {
-#ifdef PADDLE_WITH_CUDA
-        auto& dev_ctx =
-            context.template device_context<platform::CUDADeviceContext>();
-        memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
-                     p_x_grad_data + dst_offset,
-                     boost::get<platform::CUDAPlace>(context.GetPlace()),
-                     p_out_grad_data + src_offset, bytes, dev_ctx.stream());
-#endif
-      }
-    }
+    xg_tensor_ptr->mutable_data<T>(context.GetPlace());
+    framework::Copy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr);
+    xg_tensor_ptr->Resize(x_tensor_ptr->dims());
   }
 };
 
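Note on PATCH 4/5: the per-sequence memory::Copy loop can be deleted because, under the divisibility check, source and destination offsets coincide in the flat buffer: out_lod[0][i] * new_dim == in_lod[0][i] * in_width for every i, so the concatenated data is bit-identical and one flat framework::Copy plus a Resize suffices. A short Python check of that invariant on the DOC example:

    x_lod0, in_width, new_dim = [0, 2, 6], 2, 4
    out_lod0 = [0, 1, 3]  # from the LoD rule above
    assert all(o * new_dim == i * in_width for o, i in zip(out_lod0, x_lod0))

The same collapse happens in the unit test below, where the per-sequence copy loop becomes a single out.ravel()[:] = x.ravel()[:].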
diff --git a/python/paddle/v2/fluid/tests/test_sequence_reshape.py b/python/paddle/v2/fluid/tests/test_sequence_reshape.py
index 91ff275821..857b15237a 100644
--- a/python/paddle/v2/fluid/tests/test_sequence_reshape.py
+++ b/python/paddle/v2/fluid/tests/test_sequence_reshape.py
@@ -40,14 +40,7 @@ class TestSequenceReshape(OpTest):
             assert int(offset) * dimension == seq_len * x_width
             out_lod[0].append(out_lod[0][-1] + int(offset))
         out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
-        for i in xrange(len(x_lod[0]) - 1):
-            x_offset = x_lod[0][i] * x_width
-            out_offset = out_lod[0][i] * dimension
-            out_count = (out_lod[0][i + 1] - out_lod[0][i]) * dimension
-            x_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width
-            count = min(out_count, x_count)
-            out.ravel()[out_offset:out_offset + count] = x.ravel()[
-                x_offset:x_offset + count]
+        out.ravel()[:] = x.ravel()[:]
         return out, out_lod
 
     def test_check_output(self):
@@ -72,5 +65,20 @@ class TestSequenceReshape_reduce(TestSequenceReshape):
         self.outputs = {'Out': (out, out_lod)}
 
 
+class TestSequenceReshape_same(TestSequenceReshape):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 12
+        x_lod = [[0, 4, 6, 8, 12]]
+        x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+
 if __name__ == '__main__':
     unittest.main()
From b07ca1de1fa53a6cafa177b474bd6d79b3c1aef6 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Fri, 19 Jan 2018 16:03:24 +0800
Subject: [PATCH 5/5] resize before computing LoD.

---
 paddle/operators/sequence_reshape_op.h | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h
index 623904ec7c..dd9b611250 100644
--- a/paddle/operators/sequence_reshape_op.h
+++ b/paddle/operators/sequence_reshape_op.h
@@ -46,8 +46,8 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
     } else {
       auto& out_lod = *out->mutable_lod();
       out_lod.resize(1);
-      out_lod[0].clear();
-      out_lod[0].push_back(0);
+      out_lod[0].resize(seq_num + 1);
+      out_lod[0][0] = 0;
       for (int i = 0; i < seq_num; ++i) {
         size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
         size_t offset = 0;
@@ -57,11 +57,10 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
                           "be divided by new_dim with no remainder for each "
                           "sequence. The %dth sequence is invalid.",
                           i + 1);
-        out_lod[0].push_back(out_lod[0].back() + offset);
+        out_lod[0][i + 1] = out_lod[0][i] + offset;
       }
     }
 
-    out->mutable_data<T>(context.GetPlace());
     framework::Copy(*in, context.GetPlace(), out);
     out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
   }
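Closing note on the series: the final operator recomputes the level-0 LoD (reusing it untouched when in_width == new_dim) and then performs one flat copy plus a Resize; the gradient is the same flat copy in reverse. A self-contained NumPy reference model reproducing all three shipped test cases (a sketch under the op's divisibility assumption; the helper name is hypothetical):

    import numpy as np

    def sequence_reshape(x, x_lod0, new_dim):
        # recompute the level-0 LoD, then a flat row-major reshape
        # (framework::Copy + Resize in the C++ kernel)
        in_width = x.shape[1]
        if in_width == new_dim:
            return x.copy(), list(x_lod0)
        out_lod0 = [0]
        for i in range(len(x_lod0) - 1):
            total = (x_lod0[i + 1] - x_lod0[i]) * in_width
            assert total % new_dim == 0  # mirrors the PADDLE_ENFORCE_EQ check
            out_lod0.append(out_lod0[-1] + total // new_dim)
        return x.reshape(-1, new_dim), out_lod0

    cases = [
        ((11, 24), [0, 4, 5, 8, 11], 12, [0, 8, 10, 16, 22]),  # lengths double
        ((12, 12), [0, 4, 6, 8, 12], 24, [0, 2, 3, 4, 6]),     # lengths halve
        ((12, 12), [0, 4, 6, 8, 12], 12, [0, 4, 6, 8, 12]),    # same width
    ]
    for shape, lod0, new_dim, want in cases:
        out, out_lod0 = sequence_reshape(np.random.rand(*shape), lod0, new_dim)
        assert out_lod0 == want and out.shape == (want[-1], new_dim)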