From 332b665fc789efb7249ee9791af714842ea68e66 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 19 Mar 2018 17:56:12 +0800 Subject: [PATCH 1/3] Enhanced cpp implementation and unit test. --- paddle/fluid/operators/lod_reset_op.cc | 79 +++++++++++-------- paddle/fluid/operators/lod_reset_op.cu | 8 +- paddle/fluid/operators/lod_reset_op.h | 43 ++++++---- .../tests/unittests/test_lod_reset_op.py | 25 +++++- 4 files changed, 101 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 6a66297cb8..6599e183ef 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - // input check PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LoDResetOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of LoDResetOp should not be null."); - // If target LoD is not set form Input(), then it must be set from Attr(). - if (!ctx->HasInput("TargetLoD")) { + + if (!ctx->HasInput("Y")) { auto level0 = ctx->Attrs().Get>("target_lod"); - PADDLE_ENFORCE(level0.size() > 1, - "Target LoD is not found, should be set to be a valid one " - "through Input() or Attr()."); + PADDLE_ENFORCE_GT(level0.size(), 1, + "If Input(Y) is not provided, the target lod should be " + "specified by attribute `target_lod`."); } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); } @@ -50,36 +49,42 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { public: LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "(LoDTensor) The input tensor of lod_reset operator."); - AddInput("TargetLoD", - "(Tensor, optional) The target level 0 LoD from Input().") + AddInput("X", + "(Tensor, LoDTensor) Input variable of LoDResetOp which " + "could be a Tensor or LoDTensor, where the data of output " + "variable inherits from."); + AddInput("Y", + "(Tensor, LoDTensor, optional) If provided, lod of Input(Y) would " + "be considered as the target lod first, otherwise data of " + "Input(Y) would be considered as the target lod.") .AsDispensable(); - AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator."); + AddOutput("Out", + "(LoDTensor) Output variable of LoDResetOp which should be a " + "LoDTensor."); AddAttr>("target_lod", "The target level 0 LoD from Attr().") .SetDefault(std::vector{}); AddComment(R"DOC(LoDReset operator -Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or -Attr(target_lod), or set LoD for Input(X) if it doesn't have one. -Currently the lod_reset operator only supports the reset of level 0 LoD. -At least one of Input(TargetLoD) and Attr(target_lod) must be set, -and if both of them are set, Input(TargetLoD) will be chosen as the -target LoD. +Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y` +provided, `Y.lod` would be considered as target LoD first, otherwise `Y.data` +would be considered as target LoD. If `Y` is not provided, target LoD should be +specified by attribute `target_lod`. If target LoD is specified by `Y.data` or +`target_lod`, only one level LoD is supported. An example: -Given a float LoDTensor X with shape (6, 1), its transpose form represents - - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], -with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like +Given a 1-level LoDTensor input(X) + X.lod = [[ 0, 2, 5 6 ]] + X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + X.dims = [6, 1] - [1.0, 2.0], [3.0, 4.0, 5.0], [6.0]. +target_lod: [0, 4, 6] -If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and -the sequences that the LoDTensor Output(Out) contains becomes: - - [1.0, 2.0, 3.0, 4.0], [5.0, 6.0]. +then we get an 1-level LoDTensor + Out.lod = [[ 0, 4, 6 ]] + Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + Out.dims = [6, 1] )DOC"); } @@ -90,10 +95,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of LoDResetGradOp should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + "Input(Out@Grad) of LoDResetGradOp should not be null."); + + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ x_grad_name); + } } protected: @@ -111,9 +122,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad, ops::LoDResetGradOp); -REGISTER_OP_CPU_KERNEL(lod_reset, - ops::LoDResetKernel, - ops::LoDResetKernel); +REGISTER_OP_CPU_KERNEL( + lod_reset, ops::LoDResetKernel, + ops::LoDResetKernel, + ops::LoDResetKernel, + ops::LoDResetKernel); REGISTER_OP_CPU_KERNEL( lod_reset_grad, ops::LoDResetGradKernel, - ops::LoDResetGradKernel); + ops::LoDResetGradKernel, + ops::LoDResetGradKernel, + ops::LoDResetGradKernel); diff --git a/paddle/fluid/operators/lod_reset_op.cu b/paddle/fluid/operators/lod_reset_op.cu index b0e87a851a..888d4c12eb 100644 --- a/paddle/fluid/operators/lod_reset_op.cu +++ b/paddle/fluid/operators/lod_reset_op.cu @@ -18,8 +18,12 @@ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( lod_reset, ops::LoDResetKernel, - ops::LoDResetKernel); + ops::LoDResetKernel, + ops::LoDResetKernel, + ops::LoDResetKernel); REGISTER_OP_CUDA_KERNEL( lod_reset_grad, ops::LoDResetGradKernel, - ops::LoDResetGradKernel); + ops::LoDResetGradKernel, + ops::LoDResetGradKernel, + ops::LoDResetGradKernel); diff --git a/paddle/fluid/operators/lod_reset_op.h b/paddle/fluid/operators/lod_reset_op.h index 8186d4f826..99f01c2a25 100644 --- a/paddle/fluid/operators/lod_reset_op.h +++ b/paddle/fluid/operators/lod_reset_op.h @@ -26,35 +26,46 @@ class LoDResetKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output("Out"); auto* in = ctx.Input("X"); - auto* lod_t = ctx.Input("TargetLoD"); + auto* lod_t = ctx.Input("Y"); + + out->ShareDataWith(*in); std::vector level0; if (lod_t) { - auto* lod = lod_t->data(); - if (platform::is_gpu_place(ctx.GetPlace())) { - framework::Tensor lod_cpu; - framework::TensorCopy(*lod_t, platform::CPUPlace(), - ctx.device_context(), &lod_cpu); - lod = lod_cpu.data(); + if (lod_t->lod().size() > 0) { + auto y_lod = lod_t->lod(); + auto last_level = y_lod[y_lod.size() - 1]; + PADDLE_ENFORCE_EQ(last_level.back(), in->dims()[0], + "Last value of `Y`'s last level LoD should be equal " + "to the first dimension of `X`"); + out->set_lod(y_lod); + return; // early return, since lod already set + } else { + auto* lod = lod_t->data(); + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::Tensor lod_cpu; + framework::TensorCopy(*lod_t, platform::CPUPlace(), + ctx.device_context(), &lod_cpu); + lod = lod_cpu.data(); + } + level0 = std::vector(lod, lod + lod_t->numel()); } - level0 = std::vector(lod, lod + lod_t->numel()); } else { level0 = ctx.Attr>("target_lod"); } - PADDLE_ENFORCE(level0.size() > 1UL, - "The size of target LoD should be greater than 1."); - PADDLE_ENFORCE(level0[0] == 0, - "Target LoD should be a vector starting from 0."); - PADDLE_ENFORCE(level0.back() == in->dims()[0], - "Target LoD should be a vector end with the " - "first dimension of Input(X)."); + PADDLE_ENFORCE_GT(level0.size(), 1UL, + "Size of target LoD should be greater than 1."); + PADDLE_ENFORCE_EQ(level0[0], 0, + "Target LoD should be a vector starting from 0."); + PADDLE_ENFORCE_EQ(level0.back(), in->dims()[0], + "Target LoD should be a vector end with the " + "first dimension of Input(X)."); for (size_t i = 0; i < level0.size() - 1; ++i) { PADDLE_ENFORCE(level0[i + 1] > level0[i], "Target LoD should be an ascending vector."); } - out->ShareDataWith(*in); // cast level0 to size_t std::vector ulevel0(level0.size(), 0); std::transform(level0.begin(), level0.end(), ulevel0.begin(), diff --git a/python/paddle/fluid/tests/unittests/test_lod_reset_op.py b/python/paddle/fluid/tests/unittests/test_lod_reset_op.py index 3bf8230f87..6b6d4c824a 100644 --- a/python/paddle/fluid/tests/unittests/test_lod_reset_op.py +++ b/python/paddle/fluid/tests/unittests/test_lod_reset_op.py @@ -42,7 +42,7 @@ class TestLodResetOpByInput(OpTest): target_lod_0 = [0, 4, 7, 10] self.inputs = { 'X': (x, lod), - 'TargetLoD': np.array([target_lod_0]).astype('int32') + 'Y': np.array([target_lod_0]).astype('int32') } self.outputs = {'Out': (x, [target_lod_0])} @@ -50,7 +50,7 @@ class TestLodResetOpByInput(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD")) + self.check_grad(["X"], "Out", no_grad_set=set("Y")) class TestLodResetOpBoth(OpTest): @@ -62,7 +62,7 @@ class TestLodResetOpBoth(OpTest): target_lod_0_in = [0, 4, 7, 10] self.inputs = { 'X': (x, lod), - 'TargetLoD': np.array(target_lod_0_in).astype('int32') + 'Y': np.array(target_lod_0_in).astype('int32') } self.attrs = {'target_lod': target_lod_0_attr} self.outputs = {'Out': (x, [target_lod_0_in])} @@ -71,7 +71,24 @@ class TestLodResetOpBoth(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD")) + self.check_grad(["X"], "Out", no_grad_set=set("Y")) + + +class TestLodResetOpYIsLoDTensor(OpTest): + def setUp(self): + self.op_type = "lod_reset" + x = np.random.random((10, 20)).astype("float32") + lod = [[0, 3, 5, 10]] + y = np.random.random((10, 10)).astype("float32") + target_lod_0 = [[0, 4, 7, 10]] + self.inputs = {'X': (x, lod), 'Y': (y, target_lod_0)} + self.outputs = {'Out': (x, target_lod_0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out", no_grad_set=set("Y")) if __name__ == '__main__': From 869a6f9cea8ebccda5701009fe61f3b9f684e43d Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 19 Mar 2018 21:49:33 +0800 Subject: [PATCH 2/3] Add python wrapper. --- paddle/fluid/operators/lod_reset_op.cc | 61 ++++++++--- python/paddle/fluid/layers/nn.py | 100 +++++++++++++++++- .../fluid/tests/unittests/test_layers.py | 9 ++ 3 files changed, 155 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 6599e183ef..7d5687f2d0 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -30,7 +30,7 @@ class LoDResetOp : public framework::OperatorWithKernel { if (!ctx->HasInput("Y")) { auto level0 = ctx->Attrs().Get>("target_lod"); PADDLE_ENFORCE_GT(level0.size(), 1, - "If Input(Y) is not provided, the target lod should be " + "If Input(Y) not provided, the target lod should be " "specified by attribute `target_lod`."); } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); @@ -54,9 +54,10 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { "could be a Tensor or LoDTensor, where the data of output " "variable inherits from."); AddInput("Y", - "(Tensor, LoDTensor, optional) If provided, lod of Input(Y) would " - "be considered as the target lod first, otherwise data of " - "Input(Y) would be considered as the target lod.") + "(Tensor, LoDTensor, optional) If provided and Y is LoDTensor, " + "lod of Input(Y) would be considered as the target lod first, " + "otherwise data of Input(Y) would be considered as the " + "target lod.") .AsDispensable(); AddOutput("Out", "(LoDTensor) Output variable of LoDResetOp which should be a " @@ -67,25 +68,59 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC(LoDReset operator Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y` -provided, `Y.lod` would be considered as target LoD first, otherwise `Y.data` -would be considered as target LoD. If `Y` is not provided, target LoD should be -specified by attribute `target_lod`. If target LoD is specified by `Y.data` or -`target_lod`, only one level LoD is supported. +provided and `Y` is a LoDTensor, `Y.lod` would be considered as target LoD +first, otherwise `Y.data` would be considered as target LoD. If `Y` is not +provided, target LoD should be specified by attribute `target_lod`. +If target LoD is specified by `Y.data` or `target_lod`, only one level LoD +is supported. -An example: +Example 1: -Given a 1-level LoDTensor input(X) - X.lod = [[ 0, 2, 5 6 ]] +Given a 1-level LoDTensor input(X): + X.lod = [[ 0, 2, 5 6 ]] X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] X.dims = [6, 1] -target_lod: [0, 4, 6] +attr(target_lod): [0, 4, 6] -then we get an 1-level LoDTensor +then we get a 1-level LoDTensor: Out.lod = [[ 0, 4, 6 ]] Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] Out.dims = [6, 1] +Example 2: + +Given a 1-level LoDTensor input(X): + X.lod = [[ 0, 2, 5 6 ]] + X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + X.dims = [6, 1] + +input(Y) is a Tensor: + Y.data = [[0, 2, 6]] + Y.dims = [1, 3] + +then we get a 1-level LoDTensor: + Out.lod = [[ 0, 2, 6 ]] + Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + Out.dims = [6, 1] + +Example 3: + +Given a 1-level LoDTensor input(X): + X.lod = [[ 0, 2, 5 6 ]] + X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + X.dims = [6, 1] + +input(Y) is a 2-level LoDTensor: + Y.lod = [[0, 2, 4], [0, 2, 5, 6]] + Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]] + Y.dims = [6, 1] + +then we get a 2-level LoDTensor: + Out.lod = [[0, 2, 4], [0, 2, 5, 6]] + Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + Out.dims = [6, 1] + )DOC"); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bf161d6618..8dced4bbfc 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -73,6 +73,7 @@ __all__ = [ 'smooth_l1', 'one_hot', 'autoincreased_step_counter', + 'lod_reset', ] @@ -2225,7 +2226,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): keep_dim (bool|False): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. - name(str|None): A name for this layer(optional). If set None, the + name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. Returns: @@ -2241,7 +2242,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_prod(x) # [0.0002268] fluid.layers.reduce_prod(x, dim=0) # [0.02, 0.06, 0.3, 0.63] fluid.layers.reduce_prod(x, dim=-1) # [0.027, 0.0084] - fluid.layers.reduce_prod(x, dim=1, + fluid.layers.reduce_prod(x, dim=1, keep_dim=True) # [[0.027], [0.0084]] """ helper = LayerHelper('reduce_prod', **locals()) @@ -3292,3 +3293,98 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): counter.stop_gradient = True return counter + + +def lod_reset(x, y, target_lod=None): + """ + LoD Reset Operator. Set LoD of **x** to a new one specified by **y** or + **target_lod**. When **y** provided, **y.lod** would be considered as target + LoD first, otherwise **y.data** would be considered as target LoD. If **y** + is not provided, target LoD should be specified by **target_lod**. + If target LoD is specified by **Y.data** or **target_lod**, only one level + LoD is supported. + + .. code-block:: text + + * Example 1: + + Given a 1-level LoDTensor x: + x.lod = [[ 0, 2, 5 6 ]] + x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + x.dims = [6, 1] + + target_lod: [0, 4, 6] + + then we get a 1-level LoDTensor: + out.lod = [[ 0, 4, 6 ]] + out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + out.dims = [6, 1] + + * Example 2: + + Given a 1-level LoDTensor x: + x.lod = [[ 0, 2, 5 6 ]] + x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + x.dims = [6, 1] + + y is a Tensor: + y.data = [[0, 2, 6]] + y.dims = [1, 3] + + then we get a 1-level LoDTensor: + out.lod = [[ 0, 2, 6 ]] + out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + out.dims = [6, 1] + + * Example 3: + + Given a 1-level LoDTensor x: + x.lod = [[ 0, 2, 5 6 ]] + x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + x.dims = [6, 1] + + y is a 2-level LoDTensor: + y.lod = [[0, 2, 4], [0, 2, 5, 6]] + y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]] + y.dims = [6, 1] + + then we get a 2-level LoDTensor: + out.lod = [[0, 2, 4], [0, 2, 5, 6]] + out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + out.dims = [6, 1] + + Args: + x (Variable): Input variable which could be a Tensor or LodTensor. + y (Variable|None): If provided, output's LoD would be derived from y. + target_lod (list|tuple|None): One level LoD which should be considered + as target LoD when y not provided. + + Returns: + Variable: Output variable with LoD specified by this operator. + + Raises: + ValueError: If y and target_lod are both None. + + Examples: + .. code-block:: python + + x = layers.data(name='x', shape=[10]) + y = layers.data(name='y', shape=[10, 20], lod_level=2) + out = layers.lod_reset(x=x, y=y) + """ + helper = LayerHelper("lod_reset", **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + if y is not None: + helper.append_op( + type="lod_reset", inputs={'X': x, + 'Y': y}, outputs={'Out': out}) + elif target_lod is not None: + helper.append_op( + type="lod_reset", + inputs={'X': x}, + attrs={'target_lod': target_lod}, + outputs={'Out': out}) + else: + raise ValueError("y and target_lod should not be both None.") + + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 90d70aa39f..744a762ae7 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -327,6 +327,15 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(loss) print(str(program)) + def test_lod_reset(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[10], dtype='float32') + y = layers.data( + name='y', shape=[10, 20], dtype='float32', lod_level=2) + print(layers.lod_reset(x=x, y=y)) + print(str(program)) + if __name__ == '__main__': unittest.main() From cd11b1bd5ce30c91166bfb131e8f2618703e13c8 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 19 Mar 2018 22:15:20 +0800 Subject: [PATCH 3/3] Set default value of y to None. --- python/paddle/fluid/layers/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8dced4bbfc..9656dcf94f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3295,7 +3295,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): return counter -def lod_reset(x, y, target_lod=None): +def lod_reset(x, y=None, target_lod=None): """ LoD Reset Operator. Set LoD of **x** to a new one specified by **y** or **target_lod**. When **y** provided, **y.lod** would be considered as target