From 352fa41a16f995b29ca8c8da78a87bee04dc496b Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 15 Mar 2018 12:19:57 +0800 Subject: [PATCH 01/26] Finish adapting forward. --- paddle/fluid/operators/sequence_expand_op.cc | 106 ++++++++++++++----- paddle/fluid/operators/sequence_expand_op.cu | 11 +- paddle/fluid/operators/sequence_expand_op.h | 74 ++++++++----- 3 files changed, 140 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_expand_op.cc index a5d84d629b..acb6eb82a2 100644 --- a/paddle/fluid/operators/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_expand_op.cc @@ -17,7 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using framework::Tensor; +using framework::LoDTensor; class SequenceExpandOp : public framework::OperatorWithKernel { public: @@ -25,15 +25,67 @@ class SequenceExpandOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X")); - PADDLE_ENFORCE(ctx->HasOutput("Out")); - PADDLE_ENFORCE(ctx->HasInput("Y")); - framework::DDim out_dim; - auto y_dim = ctx->GetInputDim("Y"); - out_dim = ctx->GetInputDim("X"); - out_dim[0] = y_dim[0]; - ctx->ShareLoD("Y", "Out"); - ctx->SetOutputDim("Out", out_dim); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceExpandOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of SequenceExpandOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceExpandOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2U, + "Dimension number of Input(X) should be 2."); + int ref_level = ctx->Attrs().Get("ref_level"); + + if (ctx->IsRuntime()) { + framework::Variable* x_var = + boost::get(ctx->GetInputVarPtrs("X")[0]); + framework::Variable* y_var = + boost::get(ctx->GetInputVarPtrs("Y")[0]); + + auto& x_lod = x_var->Get().lod(); + auto& y_lod = y_var->Get().lod(); + + PADDLE_ENFORCE_LE(x_lod.size(), 1, + "Number of lod level of Input(X) should not be " + "greater than 1."); + + PADDLE_ENFORCE(x_lod.size() == y_lod.size() || x_lod.size() == 0, + "Number of lod level of Input(X) either equal to 0 " + "or equal to that of Input(Y)."); + + int64_t out_first_dim = 0; + if (y_lod[ref_level].size() < 1) { + out_first_dim = x_dims[0]; + } else { + if (x_lod.size() == 1) { // X is LoDTensor + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int x_seq_len = x_lod[0][i] - x_lod[0][i - 1]; + out_first_dim += + (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len; + } + } else { // X is normal Tensor + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + out_first_dim += y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + } + } + } + ctx->SetOutputDim("Out", {out_first_dim, x_dims[1]}); + } else { + framework::VarDesc* in_reader = + boost::get(ctx->GetInputVarPtrs("Y")[0]); + int lod_level_num = in_reader->GetLoDLevels().size(); + + PADDLE_ENFORCE_GE(ref_level, 0, + "Level of referred lod should be greater or " + "equal to 0."); + + PADDLE_ENFORCE_LT(ref_level, lod_level_num, + "Level of referred lod should be smaller than " + "level number of Input(Y)."); + + ctx->SetOutputDim("Out", {-1, x_dims[1]}); + } } }; @@ -42,17 +94,15 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker { SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", 
- "(Tensor or LoDTensor) The input(X) of this operator can be a " - "LoDTensor or a base Tensor."); + "(LoDTensor, default LoDTensor) A 2-D LoDTensor whose lod " + "level is at most 1."); AddInput("Y", - "(LoDTensor)The reference input(Y) of sequence_expand op." - "It must be a LoDTensor with k-level(k>0)." - "The input(X) will be expanded according to LOD of input(Y)." - "The element numbers of last level in input(Y) " - "must be equal to dims[0] of input(X)."); + "(LoDTensor, default LoDTensor) Referred LoDTensor whose " + "lod (specified level) is referred by Input(X)."); AddOutput("Out", - "(LodTensor)The output of sequence_expand op." - "The lod of output will be as same as input(Y)'s lod."); + "(LodTensor, default LoDTensor) Output LoDTensor which is " + "generated from Input(X) by referring lod of Input(Y)."); + AddAttr("ref_level", "Specify lod level of Input(Y)."); AddComment(R"DOC( Sequence Expand Operator. @@ -129,12 +179,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X")); - PADDLE_ENFORCE(ctx->HasInput("Out")); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "The input(Out@GRAD) should not be null"); + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { ctx->SetOutputDim(x_grad_name, x_dims); } @@ -149,7 +201,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker, sequence_expand_grad, ops::SequenceExpandOpGrad); REGISTER_OP_CPU_KERNEL( sequence_expand, - ops::SequenceExpandKernel); + ops::SequenceExpandKernel, + ops::SequenceExpandKernel, + ops::SequenceExpandKernel, + ops::SequenceExpandKernel); REGISTER_OP_CPU_KERNEL( sequence_expand_grad, - ops::SequenceExpandGradKernel); + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel); diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu index 26622d23af..bb51bb2902 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_expand_op.cu @@ -18,7 +18,14 @@ limitations under the License. 
*/ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( sequence_expand, - ops::SequenceExpandKernel); + ops::SequenceExpandKernel, + ops::SequenceExpandKernel, + ops::SequenceExpandKernel, + ops::SequenceExpandKernel); REGISTER_OP_CUDA_KERNEL( sequence_expand_grad, - ops::SequenceExpandGradKernel); + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel, + ops::SequenceExpandGradKernel); diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index 76dde976db..2b4fa016f7 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -28,33 +28,57 @@ class SequenceExpandKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); - auto* out = context.Output("Out"); - const T* x_data = x->data(); - auto x_dims = x->dims(); auto* y = context.Input("Y"); - PADDLE_ENFORCE(!y->lod().empty(), "y should have lod"); - PADDLE_ENFORCE_EQ(static_cast(x_dims[0]), - y->lod().back().size() - 1, - "The size of last lod level in Input(Y)" - "must be equal to dims[0] of Input(X)."); - out->set_lod(y->lod()); - auto* place = - context.template device_context().eigen_device(); - size_t element_len = framework::product(x_dims) / x_dims[0]; - T* out_data = out->mutable_data(context.GetPlace()); - auto out_starts = out->lod().back(); + auto* out = context.Output("Out"); + int ref_level = context.Attr("ref_level"); - for (size_t i = 0; i < out_starts.size() - 1; i++) { - int scale = out_starts[i + 1] - out_starts[i]; - Eigen::TensorMap< - Eigen::Tensor> - x_t(x_data, 1, element_len); - Eigen::TensorMap> - out_t(out_data, scale, element_len); - Eigen::array cast({{scale, 1}}); - out_t.device(*place) = x_t.broadcast(cast); - x_data += element_len; - out_data += element_len * scale; + auto& x_lod = x->lod(); + auto& y_lod = y->lod(); + + PADDLE_ENFORCE_GE(ref_level, 0, + "Value of attribute `ref_level` should be greater or " + "equal to 0."); + + PADDLE_ENFORCE_LT(ref_level, y_lod.size(), + "Value of attribute `ref_level` should be smaller than " + "level number of Y's lod."); + + if (y_lod[ref_level].size() < 1) { + framework::TensorCopy(*x, context.GetPlace(), out); + return; + } + + if (x_lod.size() == 0) { + int out_start = 0; + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + auto x_sub_tensor = x->Slice(i - 1, i); + for (size_t j = 0; j < repeat_num; ++j) { + auto out_sub_tensor = out->Slice(out_start, out_start + 1); + framework::TensorCopy(x_sub_tensor, context.GetPlace(), + &out_sub_tensor); + out_start++; + } + } + } else { + auto& out_lod = *out->mutable_lod(); + out_lod.resize(1); + out_lod[0].resize(1); + out_lod[0][0] = 0; + int out_idx = 0; + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + int x_seq_len = x_lod[0][i] - x_lod[0][i - 1]; + auto x_sub_tensor = x->Slice(x_lod[0][i], x_lod[0][i - 1]); + for (size_t j = 0; j < repeat_num; ++j) { + auto out_sub_tensor = + out->Slice(out_lod[0][out_idx], out_lod[0][out_idx] + x_seq_len); + framework::TensorCopy(x_sub_tensor, context.GetPlace(), + &out_sub_tensor); + out_lod[0].push_back(out_lod[0][out_idx] + x_seq_len); + out_idx++; + } + } } } }; From bf3f56e899cd4205c6e3c5cea0c4c1c69819ae84 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 15 Mar 2018 22:33:33 +0800 Subject: 
[PATCH 02/26] Finish adaption for backward. --- paddle/fluid/operators/math/math_function.cc | 2 + paddle/fluid/operators/math/math_function.cu | 2 + paddle/fluid/operators/sequence_expand_op.cc | 51 ++++--- paddle/fluid/operators/sequence_expand_op.h | 135 +++++++++++-------- 4 files changed, 108 insertions(+), 82 deletions(-) diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 35d251f71a..17e576a9d5 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -371,6 +371,8 @@ template struct RowwiseAdd; template struct ColwiseSum; template struct ColwiseSum; +template struct ColwiseSum; +template struct ColwiseSum; template struct RowwiseSum; template struct RowwiseSum; diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 3abbcdb71d..c6ca2693a0 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -422,6 +422,8 @@ struct RowwiseAdd { template struct RowwiseAdd; template struct RowwiseAdd; template struct ColwiseSum; +template struct ColwiseSum; +template struct ColwiseSum; // template struct ColwiseSum; // The ColwiseSum failed in debug mode, // and only failed for this case. So reimplemented it. diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_expand_op.cc index acb6eb82a2..25a8283858 100644 --- a/paddle/fluid/operators/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_expand_op.cc @@ -33,9 +33,10 @@ class SequenceExpandOp : public framework::OperatorWithKernel { "Output(Out) of SequenceExpandOp should not be null."); auto x_dims = ctx->GetInputDim("X"); + int ref_level = ctx->Attrs().Get("ref_level"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Dimension number of Input(X) should be 2."); - int ref_level = ctx->Attrs().Get("ref_level"); if (ctx->IsRuntime()) { framework::Variable* x_var = @@ -51,39 +52,37 @@ class SequenceExpandOp : public framework::OperatorWithKernel { "greater than 1."); PADDLE_ENFORCE(x_lod.size() == y_lod.size() || x_lod.size() == 0, - "Number of lod level of Input(X) either equal to 0 " - "or equal to that of Input(Y)."); + "Level number of Input(X)'s lod should be either equal " + "to 0 or equal to that of Input(Y)."); + + PADDLE_ENFORCE_GT(y_lod.size(), 0, + "Level number of Input(Y)'s lod should be " + "greater than 0."); + + PADDLE_ENFORCE( + ref_level == -1 || + (ref_level >= 0 && ref_level < static_cast(y_lod.size())), + "Invlid `ref_level`, which should be either equal to -1 " + "or in [0, %d)", + y_lod.size()); + + if (ref_level == -1) ref_level = y_lod.size() - 1; int64_t out_first_dim = 0; - if (y_lod[ref_level].size() < 1) { + if (y_lod[ref_level].size() <= 1) { out_first_dim = x_dims[0]; } else { - if (x_lod.size() == 1) { // X is LoDTensor - for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { - int x_seq_len = x_lod[0][i] - x_lod[0][i - 1]; - out_first_dim += - (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len; - } - } else { // X is normal Tensor - for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { - out_first_dim += y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int x_seq_len = 1; + if (x_lod.size() == 1) { + x_seq_len = x_lod[0][i] - x_lod[0][i - 1]; } + out_first_dim += + (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len; } } ctx->SetOutputDim("Out", {out_first_dim, x_dims[1]}); } else { - 
framework::VarDesc* in_reader = - boost::get(ctx->GetInputVarPtrs("Y")[0]); - int lod_level_num = in_reader->GetLoDLevels().size(); - - PADDLE_ENFORCE_GE(ref_level, 0, - "Level of referred lod should be greater or " - "equal to 0."); - - PADDLE_ENFORCE_LT(ref_level, lod_level_num, - "Level of referred lod should be smaller than " - "level number of Input(Y)."); - ctx->SetOutputDim("Out", {-1, x_dims[1]}); } } @@ -102,7 +101,7 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(LodTensor, default LoDTensor) Output LoDTensor which is " "generated from Input(X) by referring lod of Input(Y)."); - AddAttr("ref_level", "Specify lod level of Input(Y)."); + AddAttr("ref_level", "Specify lod level of Input(Y).").SetDefault(-1); AddComment(R"DOC( Sequence Expand Operator. diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index 2b4fa016f7..8cbfdf177e 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memcpy.h" -#include "unsupported/Eigen/CXX11/Tensor" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { @@ -32,52 +32,53 @@ class SequenceExpandKernel : public framework::OpKernel { auto* out = context.Output("Out"); int ref_level = context.Attr("ref_level"); + out->mutable_data(context.GetPlace()); auto& x_lod = x->lod(); auto& y_lod = y->lod(); - PADDLE_ENFORCE_GE(ref_level, 0, - "Value of attribute `ref_level` should be greater or " - "equal to 0."); + PADDLE_ENFORCE_GT(y_lod.size(), 0, + "Level number of `Y`'s lod should be greater than 0."); - PADDLE_ENFORCE_LT(ref_level, y_lod.size(), - "Value of attribute `ref_level` should be smaller than " - "level number of Y's lod."); + PADDLE_ENFORCE( + ref_level == -1 || (ref_level >= 0 && ref_level < y_lod.size()), + "Invlid `ref_level`, which should be either equal to -1 " + "or in [0, %d)", + y_lod.size()); - if (y_lod[ref_level].size() < 1) { + if (ref_level == -1) ref_level = y_lod.size() - 1; + + if (y_lod[ref_level].size() <= 1) { framework::TensorCopy(*x, context.GetPlace(), out); return; } - if (x_lod.size() == 0) { - int out_start = 0; - for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { - int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; - auto x_sub_tensor = x->Slice(i - 1, i); - for (size_t j = 0; j < repeat_num; ++j) { - auto out_sub_tensor = out->Slice(out_start, out_start + 1); - framework::TensorCopy(x_sub_tensor, context.GetPlace(), - &out_sub_tensor); - out_start++; - } - } - } else { - auto& out_lod = *out->mutable_lod(); + auto& out_lod = *out->mutable_lod(); + if (x_lod.size() == 1) { out_lod.resize(1); - out_lod[0].resize(1); - out_lod[0][0] = 0; - int out_idx = 0; - for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { - int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; - int x_seq_len = x_lod[0][i] - x_lod[0][i - 1]; - auto x_sub_tensor = x->Slice(x_lod[0][i], x_lod[0][i - 1]); - for (size_t j = 0; j < repeat_num; ++j) { - auto out_sub_tensor = - out->Slice(out_lod[0][out_idx], out_lod[0][out_idx] + x_seq_len); - framework::TensorCopy(x_sub_tensor, context.GetPlace(), - &out_sub_tensor); - out_lod[0].push_back(out_lod[0][out_idx] + x_seq_len); - out_idx++; + out_lod[0] = {0}; + } + + int out_offset = 0; + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int repeat_num = 
y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + int x_start = i - 1; + int x_end = i; + if (x_lod.size() == 1) { + x_start = x_lod[0][i - 1]; + x_end = x_lod[0][i]; + } + int x_seq_len = x_end - x_start; + auto x_sub_tensor = x->Slice(x_start, x_end); + for (size_t j = 0; j < repeat_num; ++j) { + int out_start = out_offset; + if (x_lod.size() == 1) { + out_start = out_lod[0][out_offset]; + out_lod[0].push_back(x_seq_len); } + auto out_sub_tensor = out->Slice(out_start, out_start + x_seq_len); + framework::TensorCopy(x_sub_tensor, context.GetPlace(), + &out_sub_tensor); + out_offset++; } } } @@ -99,27 +100,49 @@ template class SequenceExpandGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* d_out = context.Input(framework::GradVarName("Out")); + auto* g_out = context.Input(framework::GradVarName("Out")); auto* x = context.Input("X"); - auto* out = context.Input("Out"); - auto* d_x = context.Output(framework::GradVarName("X")); - auto out_last_level = out->lod().back(); - d_x->set_lod(x->lod()); - const T* d_out_data = d_out->data(); - T* d_x_data = d_x->mutable_data(context.GetPlace()); - size_t element_len = d_out->numel() / d_out->dims()[0]; - for (size_t i = 0; i < out_last_level.size() - 1; ++i) { - size_t repeat = out_last_level[i + 1] - out_last_level[i]; - Eigen::TensorMap< - Eigen::Tensor> - d_out_t(d_out_data, static_cast(repeat), element_len); - Eigen::TensorMap> - d_x_t(d_x_data, static_cast(element_len)); - auto place = - context.template device_context().eigen_device(); - d_x_t.device(*place) = d_out_t.sum(Eigen::array({{0}})); - d_out_data += (repeat * element_len); - d_x_data += element_len; + auto* y = context.Input("Y"); + auto* g_x = context.Output(framework::GradVarName("X")); + int ref_level = context.Attr("ref_level"); + + g_x->mutable_data(context.GetPlace()); + g_x->set_lod(x->lod()); + + auto& x_lod = x->lod(); + auto& y_lod = y->lod(); + + if (ref_level == -1) ref_level = y_lod.size() - 1; + + // just copy the gradient + if (y_lod[ref_level].size() <= 1) { + framework::TensorCopy(*g_out, context.GetPlace(), g_x); + return; + } + + auto& dev_ctx = context.template device_context(); + + int g_out_offset = 0; + for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { + int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; + if (repeat_num > 0) { + int x_start = i - 1; + int x_end = i; + if (x_lod.size() == 1) { + x_start = x_lod[0][i - 1]; + x_end = x_lod[0][i]; + } + int x_seq_len = x_end - x_start; + auto column = x_seq_len * x->dims()[1]; + auto g_x_sub = g_x->Slice(x_start, x_end); + g_x_sub = framework::ReshapeToMatrix(g_x_sub, column); + int g_out_end = g_out_offset + repeat_num * x_seq_len; + auto g_out_sub = g_out->Slice(g_out_offset, g_out_end); + g_out_sub = framework::ReshapeToMatrix(g_out_sub, column); + math::ColwiseSum col_sum; + col_sum(dev_ctx, g_out_sub, &g_x_sub); + g_out_offset += repeat_num * x_seq_len; + } } } }; From 58730ba131a468df2c8873d41b189d5690be10be Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 16 Mar 2018 19:51:56 +0800 Subject: [PATCH 03/26] Enhance unit test. 
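As a rough illustration of the ref_level expansion rule this series implements, much like the numpy reference computation in the enhanced unit test, a minimal sketch follows; the helper name expand_ref and the sample data are made up for illustration and are not part of the patch.

    import numpy as np

    def expand_ref(x_data, x_lod, y_lod, ref_level=-1):
        # Repeat each sequence of a 2-D x (or each row, if x has no lod)
        # according to the sequence lengths in y_lod[ref_level].
        if ref_level == -1:
            ref_level = len(y_lod) - 1
        ref = y_lod[ref_level]
        x_idx = x_lod[0] if x_lod else list(range(x_data.shape[0] + 1))
        out = np.zeros((0, x_data.shape[1]), dtype=x_data.dtype)
        for i in range(1, len(ref)):
            repeat_num = ref[i] - ref[i - 1]
            if repeat_num > 0:
                x_sub = x_data[x_idx[i - 1]:x_idx[i]]
                out = np.vstack((out, np.tile(x_sub, (repeat_num, 1))))
        return out

    # Case 1 from the operator doc: expect rows a, b, a, b, c, d, c, d.
    x = np.array([[1.], [2.], [3.], [4.]], dtype='float32')
    print(expand_ref(x, [[0, 2, 4]], [[0, 2, 4], [0, 3, 6, 7, 8]], ref_level=0))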
--- paddle/fluid/operators/sequence_expand_op.cc | 102 +++++++++--------- paddle/fluid/operators/sequence_expand_op.h | 40 ++++--- python/paddle/fluid/layers/nn.py | 49 +++++---- .../fluid/tests/unittests/test_layers.py | 4 +- .../tests/unittests/test_sequence_expand.py | 51 ++++++--- 5 files changed, 145 insertions(+), 101 deletions(-) diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_expand_op.cc index 25a8283858..2c88a53bc7 100644 --- a/paddle/fluid/operators/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_expand_op.cc @@ -33,10 +33,11 @@ class SequenceExpandOp : public framework::OperatorWithKernel { "Output(Out) of SequenceExpandOp should not be null."); auto x_dims = ctx->GetInputDim("X"); + auto out_dims = x_dims; int ref_level = ctx->Attrs().Get("ref_level"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2U, - "Dimension number of Input(X) should be 2."); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "Dimension number of Input(X) should be at least 2."); if (ctx->IsRuntime()) { framework::Variable* x_var = @@ -50,15 +51,9 @@ class SequenceExpandOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_LE(x_lod.size(), 1, "Number of lod level of Input(X) should not be " "greater than 1."); - - PADDLE_ENFORCE(x_lod.size() == y_lod.size() || x_lod.size() == 0, - "Level number of Input(X)'s lod should be either equal " - "to 0 or equal to that of Input(Y)."); - PADDLE_ENFORCE_GT(y_lod.size(), 0, "Level number of Input(Y)'s lod should be " "greater than 0."); - PADDLE_ENFORCE( ref_level == -1 || (ref_level >= 0 && ref_level < static_cast(y_lod.size())), @@ -68,6 +63,14 @@ class SequenceExpandOp : public framework::OperatorWithKernel { if (ref_level == -1) ref_level = y_lod.size() - 1; + if (x_lod.size() > 0) { + PADDLE_ENFORCE( + x_lod.size() == 0 || x_lod[0].size() == y_lod[ref_level].size(), + "Level number of Input(X)'s lod should be 0. Otherwise " + "size of Input(X)'s first level lod should be equal to " + "size of Input(Y)'s lod of referred level."); + } + int64_t out_first_dim = 0; if (y_lod[ref_level].size() <= 1) { out_first_dim = x_dims[0]; @@ -81,9 +84,12 @@ class SequenceExpandOp : public framework::OperatorWithKernel { (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len; } } - ctx->SetOutputDim("Out", {out_first_dim, x_dims[1]}); + out_dims[0] = out_first_dim; + ctx->SetOutputDim("Out", out_dims); } else { - ctx->SetOutputDim("Out", {-1, x_dims[1]}); + out_dims[0] = -1; + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } } }; @@ -105,69 +111,69 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Sequence Expand Operator. -This operator expands input(X) according to LOD of input(Y). +This operator expands `X` according to specified level lod of `Y`. Current +implementation constaints that lod level of `X` should be at most 1. Attribute +`ref_level` is used to specify which level lod of `Y` is referred to expand `X`. +If set `ref_level` to -1, then last level lod of `Y` would be referred. +Please note, rank of `X` should be at least 2, when the rank exceeds 2, `X` +would be viewed as a 2-D tensor. 
+ Following are cases to better explain how this works: + Case 1: -Given a 2-level LoDTensor input(X) - X.lod = [[0, 2, 3], - [0, 1, 3, 4]] - X.data = [a, b, c, d] +Given a 1-level LoDTensor input(X) + X.lod = [[0, 2, 4]] + X.data = [[a], [b], [c], [d]] X.dims = [4, 1] and input(Y) Y.lod = [[0, 2, 4], [0, 3, 6, 7, 8]] -with condition len(Y.lod[-1]) -1 == X.dims[0] -then we get 2-level LoDTensor - Out.lod = [[0, 2, 4], - [0, 3, 6, 7, 8]] - Out.data = [a, a, a, b, b, b, c, d] +ref_level: 0 +then we get 1-level LoDTensor + Out.lod = [[0, 2, 4, 6, 8]] + Out.data = [[a], [b], [a], [b], [c], [d], [c], [d]] Out.dims = [8, 1] Case 2: +Given 1-level LoDTensor input(X) + X.lod = [[0, 1, 4]] + X.data = [[a], [b], [c], [d]] + X.dims = [4, 1] +and input(Y) + Y.lod = [[0, 2, 4], + [0, 3, 6, 6, 8]] +ref_level: 0 +then we get 1-level LoDTensor + Out.lod = [[0, 2, 5, 8]] + Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]] + Out.dims = [8, 1] + +Case 3: + Given a common Tensor input(X) - X.data = [a, b, c] + X.data = [[a], [b], [c]] X.dims = [3, 1] and input(Y) Y.lod = [[0, 2, 3, 6]] -with condition len(Y.lod[-1]) -1 == X.dims[0] -then we get 1-level LoDTensor - Out.lod = [[0, 2, 3, 6]] - Out.data = [a, a, b, c, c, c] +ref_level: -1 +then we a common Tensor + Out.data = [[a], [a], [b], [c], [c], [c]] Out.dims = [6, 1] -Case 3: +Case 4: Given a common Tensor input(X) X.data = [[a, b], [c, d], [e, f]] X.dims = [3, 2] and input(Y) Y.lod = [[0, 2, 3, 6]] -with condition len(Y.lod[-1]) -1 == X.dims[0] -then we get 1-level LoDTensor - Out.lod = [[0, 2, 3, 6]] - Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]] +ref_level: 0 +then we get a common LoDTensor + Out.data = [[a, b], [a, b] [c, d], [e, f], [e, f], [e, f]] Out.dims = [6, 2] -Case 4: - -Given 2-level a LoDTensor input(X) - X.lod = [[0, 2, 3], - [0, 1, 3, 4]] - X.data = [a, b, c, d] - X.dims = [4, 1] -and input(Y) - Y.lod = [[0, 2, 4], - [0, 3, 6, 6, 8]] -with condition len(Y.lod[-1]) -1 == X.dims[0] -then we get 2-level LoDTensor - Out.lod = [[0, 2, 4], - [0, 3, 6, 6, 8]] - Out.data = [a, a, a, b, b, b, d, d] - Out.dims = [8, 1] - - )DOC"); } }; diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index 8cbfdf177e..eea3cf0440 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -22,6 +22,9 @@ namespace paddle { namespace operators { using LoDTensor = framework::LoDTensor; +template +using EigenMatrix = framework::EigenMatrix; template class SequenceExpandKernel : public framework::OpKernel { @@ -30,15 +33,12 @@ class SequenceExpandKernel : public framework::OpKernel { auto* x = context.Input("X"); auto* y = context.Input("Y"); auto* out = context.Output("Out"); - int ref_level = context.Attr("ref_level"); - out->mutable_data(context.GetPlace()); + int ref_level = context.Attr("ref_level"); auto& x_lod = x->lod(); auto& y_lod = y->lod(); - PADDLE_ENFORCE_GT(y_lod.size(), 0, "Level number of `Y`'s lod should be greater than 0."); - PADDLE_ENFORCE( ref_level == -1 || (ref_level >= 0 && ref_level < y_lod.size()), "Invlid `ref_level`, which should be either equal to -1 " @@ -47,6 +47,8 @@ class SequenceExpandKernel : public framework::OpKernel { if (ref_level == -1) ref_level = y_lod.size() - 1; + out->mutable_data(context.GetPlace()); + if (y_lod[ref_level].size() <= 1) { framework::TensorCopy(*x, context.GetPlace(), out); return; @@ -59,6 +61,8 @@ class SequenceExpandKernel : public framework::OpKernel { } int out_offset = 0; + auto& eigen_place 
= + *context.template device_context().eigen_device(); for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; int x_start = i - 1; @@ -68,16 +72,24 @@ class SequenceExpandKernel : public framework::OpKernel { x_end = x_lod[0][i]; } int x_seq_len = x_end - x_start; - auto x_sub_tensor = x->Slice(x_start, x_end); - for (size_t j = 0; j < repeat_num; ++j) { + if (repeat_num > 0) { + auto x_sub_tensor = x->Slice(x_start, x_end); + x_sub_tensor.Resize({1, x_sub_tensor.numel()}); int out_start = out_offset; if (x_lod.size() == 1) { out_start = out_lod[0][out_offset]; - out_lod[0].push_back(x_seq_len); } - auto out_sub_tensor = out->Slice(out_start, out_start + x_seq_len); - framework::TensorCopy(x_sub_tensor, context.GetPlace(), - &out_sub_tensor); + auto out_sub_tensor = + out->Slice(out_start, out_start + x_seq_len * repeat_num); + out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]}); + EigenMatrix::From(out_sub_tensor).device(eigen_place) = + EigenMatrix::From(x_sub_tensor) + .broadcast(Eigen::array({{repeat_num, 1}})); + } + for (int j = 0; j < repeat_num; ++j) { + if (x_lod.size() == 1) { + out_lod[0].push_back(out_lod[0].back() + x_seq_len); + } out_offset++; } } @@ -122,6 +134,9 @@ class SequenceExpandGradKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, g_x, static_cast(0)); + int g_out_offset = 0; for (size_t i = 1; i < y_lod[ref_level].size(); ++i) { int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]; @@ -133,12 +148,11 @@ class SequenceExpandGradKernel : public framework::OpKernel { x_end = x_lod[0][i]; } int x_seq_len = x_end - x_start; - auto column = x_seq_len * x->dims()[1]; auto g_x_sub = g_x->Slice(x_start, x_end); - g_x_sub = framework::ReshapeToMatrix(g_x_sub, column); + g_x_sub.Resize(flatten_to_1d(g_x_sub.dims())); int g_out_end = g_out_offset + repeat_num * x_seq_len; auto g_out_sub = g_out->Slice(g_out_offset, g_out_end); - g_out_sub = framework::ReshapeToMatrix(g_out_sub, column); + g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]}); math::ColwiseSum col_sum; col_sum(dev_ctx, g_out_sub, &g_x_sub); g_out_offset += repeat_num * x_seq_len; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bc2be4cdfe..4e6f76206e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1781,52 +1781,52 @@ def conv2d_transpose(input, return out -def sequence_expand(x, y, name=None): +def sequence_expand(x, y, ref_level=-1, name=None): """Sequence Expand Layer. This layer will expand the input variable **x** - according to LoD information of **y**. And the following examples will - explain how sequence_expand works: + according to specified level lod of **y**. Please note that lod level of + **x** is at most 1 and rank of **x** is at least 2. When rank of **x** + is greater than 2, then it would be viewed as a 2-D tensor. + Following examples will explain how sequence_expand works: .. 
code-block:: text * Case 1 x is a LoDTensor: - x.lod = [[0, 2, 3], - [0, 1, 3, 4]] - x.data = [a, b, c, d] + x.lod = [[0, 2, 4]] + x.data = [[a], [b], [c], [d]] x.dims = [4, 1] y is a LoDTensor: y.lod = [[0, 2, 4], [0, 3, 6, 7, 8]] - with condition len(y.lod[-1]) - 1 == x.dims[0] + ref_level: 0 - then output is a 2-level LoDTensor: - out.lod = [[0, 2, 4], - [0, 3, 6, 7, 8]] - out.data = [a, a, a, b, b, b, c, d] + then output is a 1-level LoDTensor: + out.lod = [[0, 2, 4, 6, 8]] + out.data = [[a], [b], [a], [b], [c], [d], [c], [d]] out.dims = [8, 1] * Case 2 x is a Tensor: - x.data = [a, b, c] + x.data = [[a], [b], [c]] x.dims = [3, 1] y is a LoDTensor: - y.lod = [[0, 2, 3, 6]] - - with condition len(y.lod[-1]) - 1 == x.dims[0] + y.lod = [[0, 2, 2, 5]] - then output is a 1-level LoDTensor: - out.lod = [[0, 2, 3, 6]] - out.data = [a, a, b, c, c, c] - out.dims = [6, 1] + ref_level: -1 + then output is a Tensor: + out.data = [[a], [a], [c], [c], [c]] + out.dims = [5, 1] Args: x (Variable): The input variable which is a Tensor or LoDTensor. y (Variable): The input variable which is a LoDTensor. + ref_level (int): Lod level of `y` to be referred by `x`. If set to -1, + refer the last level of lod. name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. + will be named automatically. Returns: Variable: The expanded variable which is a LoDTensor. @@ -1837,14 +1837,17 @@ def sequence_expand(x, y, name=None): x = fluid.layers.data(name='x', shape=[10], dtype='float32') y = fluid.layers.data(name='y', shape=[10, 20], dtype='float32', lod_level=1) - out = layers.sequence_expand(x=x, y=y) + out = layers.sequence_expand(x=x, y=y, ref_level=0) """ helper = LayerHelper('sequence_expand', input=x, **locals()) dtype = helper.input_dtype() tmp = helper.create_tmp_variable(dtype) helper.append_op( - type='sequence_expand', inputs={'X': x, - 'Y': y}, outputs={'Out': tmp}) + type='sequence_expand', + inputs={'X': x, + 'Y': y}, + outputs={'Out': tmp}, + attrs={'ref_level': ref_level}) return tmp diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 6944cca394..e56d78ae8b 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -181,8 +181,8 @@ class TestBook(unittest.TestCase): with program_guard(program): x = layers.data(name='x', shape=[10], dtype='float32') y = layers.data( - name='y', shape=[10, 20], dtype='float32', lod_level=1) - self.assertIsNotNone(layers.sequence_expand(x=x, y=y)) + name='y', shape=[10, 20], dtype='float32', lod_level=2) + self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1)) print(str(program)) def test_lstm_unit(self): diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand.py b/python/paddle/fluid/tests/unittests/test_sequence_expand.py index 957fa5d2c4..7feb509c4d 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_expand.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_expand.py @@ -27,12 +27,36 @@ class TestSequenceExpand(OpTest): def compute(self): x = self.inputs['X'] x_data, x_lod = x if type(x) == tuple else (x, None) - n = 1 + x_data.shape[0] if not x_lod else len(x_lod[0]) y_data, y_lod = self.inputs['Y'] - repeats = [((y_lod[-1][i + 1] - y_lod[-1][i])) - for i in range(len(y_lod[-1]) - 1)] - out = x_data.repeat(repeats, axis=0) - self.outputs = {'Out': out} + + if hasattr(self, 'attrs'): + ref_level = self.attrs['ref_level'] + else: + ref_level = 
len(y_lod) - 1 + + out = np.zeros(shape=((0, ) + x_data.shape[1:]), dtype=x_data.dtype) + + if x_lod is None: + x_idx = [i for i in xrange(x_data.shape[0] + 1)] + else: + x_idx = x_lod[0] + out_lod = [[0]] + + for i in xrange(1, len(y_lod[ref_level])): + repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1] + x_len = x_idx[i] - x_idx[i - 1] + if repeat_num > 0: + x_sub = x_data[x_idx[i - 1]:x_idx[i], :] + x_sub = np.repeat(x_sub, repeat_num, axis=0) + out = np.vstack((out, x_sub)) + if x_lod is not None: + for j in xrange(repeat_num): + out_lod[0].append(out_lod[0][-1] + x_len) + + if x_lod is None: + self.outputs = {'Out': out} + else: + self.outputs = {'Out': (out, out_lod)} def setUp(self): self.op_type = 'sequence_expand' @@ -52,7 +76,8 @@ class TestSequenceExpandCase1(TestSequenceExpand): x_lod = [[0, 2, 5]] y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32') y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]] - self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + self.inputs = {'X': x_data, 'Y': (y_data, y_lod)} + self.attrs = {'ref_level': 0} class TestSequenceExpandCase2(TestSequenceExpand): @@ -60,8 +85,9 @@ class TestSequenceExpandCase2(TestSequenceExpand): x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32') x_lod = [[0, 1]] y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32') - y_lod = [[0, 2]] + y_lod = [[0, 2], [0, 2]] self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + self.attrs = {'ref_level': 0} class TestSequenceExpandCase3(TestSequenceExpand): @@ -75,14 +101,9 @@ class TestSequenceExpandCase3(TestSequenceExpand): class TestSequenceExpandCase4(TestSequenceExpand): def set_data(self): - x_data = np.array( - [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape( - [2, 5]).astype('float32') - x_lod = [[ - 0, - 1, - 2, - ]] + data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3] + x_data = np.array(data).reshape([5, 2]).astype('float32') + x_lod = [[0, 2, 5]] y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32') y_lod = [[0, 1, 2], [0, 1, 2]] self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} From 3b03e3748d72bfb6a7d741bdb20d1f211d0825c8 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 16 Mar 2018 20:05:28 +0800 Subject: [PATCH 04/26] Refine some ENFORCE. --- paddle/fluid/operators/sequence_expand_op.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_expand_op.cc index 2c88a53bc7..d4bf6034ed 100644 --- a/paddle/fluid/operators/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_expand_op.cc @@ -49,7 +49,7 @@ class SequenceExpandOp : public framework::OperatorWithKernel { auto& y_lod = y_var->Get().lod(); PADDLE_ENFORCE_LE(x_lod.size(), 1, - "Number of lod level of Input(X) should not be " + "Level number of Input(X)'s lod should not be " "greater than 1."); PADDLE_ENFORCE_GT(y_lod.size(), 0, "Level number of Input(Y)'s lod should be " @@ -64,11 +64,10 @@ class SequenceExpandOp : public framework::OperatorWithKernel { if (ref_level == -1) ref_level = y_lod.size() - 1; if (x_lod.size() > 0) { - PADDLE_ENFORCE( - x_lod.size() == 0 || x_lod[0].size() == y_lod[ref_level].size(), - "Level number of Input(X)'s lod should be 0. Otherwise " - "size of Input(X)'s first level lod should be equal to " - "size of Input(Y)'s lod of referred level."); + PADDLE_ENFORCE(x_lod[0].size() == y_lod[ref_level].size(), + "Level number of Input(X)'s lod could be 0. 
Otherwise " + "size of Input(X)'s first level lod should be equal to " + "size of Input(Y)'s referred level lod."); } int64_t out_first_dim = 0; From ab3543e35ee84ebbf9fe8c11eda7318f01ab7515 Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Fri, 16 Mar 2018 14:29:49 -0700 Subject: [PATCH 05/26] Fix compilation for gcc5.4 The error is: paddle/fluid/operators/math/concat.cc:47:72: error: invalid initialization of non-const reference of type 'paddle::platform::CPUPlace&' from an rvalue of type 'paddle::platform::CPUPlace' auto& cpu_place = boost::get(context.GetPlace()); Should not use reference for cpu_place. --- paddle/fluid/operators/math/concat.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc index b542143419..b672c79afd 100644 --- a/paddle/fluid/operators/math/concat.cc +++ b/paddle/fluid/operators/math/concat.cc @@ -44,7 +44,7 @@ class ConcatFunctor { out_cols += t_cols; input_cols[i] = t_cols; } - auto& cpu_place = boost::get(context.GetPlace()); + auto cpu_place = boost::get(context.GetPlace()); // computation for (int k = 0; k < out_rows; ++k) { @@ -87,7 +87,7 @@ class ConcatGradFunctor { input_cols += t_cols; output_cols[i] = t_cols; } - auto& cpu_place = boost::get(context.GetPlace()); + auto cpu_place = boost::get(context.GetPlace()); // computation for (int k = 0; k < input_rows; ++k) { From a571ef382ee45b1e255f44c9650510f20403f4bf Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 19 Mar 2018 14:11:40 +0800 Subject: [PATCH 06/26] fix bugs --- .../fluid/operators/reader/create_double_buffer_reader_op.cc | 5 +++-- paddle/fluid/operators/reader/create_shuffle_reader_op.cc | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index d0de092947..8960fe5d63 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -159,11 +159,12 @@ void DoubleBufferReader::PrefetchThreadFunc() { if (!buffer_->Send(&batch)) { VLOG(5) << "WARNING: The double buffer channel has been closed. 
The " - "prefetch thread terminates."; - break; + "prefetch thread terminate."; + return; } } buffer_->Close(); + VLOG(5) << "Prefetch thread terminates."; } bool DoubleBufferReader::HasNext() const { diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 70e2f587dc..4ebef4aed7 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -50,7 +50,6 @@ class ShuffleReader : public framework::DecoratedReader { buffer_.clear(); buffer_.reserve(buffer_size_); iteration_pos_ = 0; - PADDLE_ENFORCE(reader_->HasNext()); for (size_t i = 0; i < buffer_size_; ++i) { if (!reader_->HasNext()) { break; From 07d38a9b9affdb68cab9fb2376cfad7f32a73fce Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 19 Mar 2018 14:32:57 +0800 Subject: [PATCH 07/26] refine patch --- paddle/fluid/operators/reader/create_shuffle_reader_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 4ebef4aed7..3a1f3805a0 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -34,6 +34,9 @@ class ShuffleReader : public framework::DecoratedReader { } void ReadNext(std::vector* out) override { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } if (iteration_pos_ >= buffer_.size()) { VLOG(10) << "Resetting shuffle buffer"; ReadIntoBuffers(); From 332b665fc789efb7249ee9791af714842ea68e66 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 19 Mar 2018 17:56:12 +0800 Subject: [PATCH 08/26] Enhanced cpp implementation and unit test. --- paddle/fluid/operators/lod_reset_op.cc | 79 +++++++++++-------- paddle/fluid/operators/lod_reset_op.cu | 8 +- paddle/fluid/operators/lod_reset_op.h | 43 ++++++---- .../tests/unittests/test_lod_reset_op.py | 25 +++++- 4 files changed, 101 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 6a66297cb8..6599e183ef 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - // input check PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LoDResetOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of LoDResetOp should not be null."); - // If target LoD is not set form Input(), then it must be set from Attr(). 
- if (!ctx->HasInput("TargetLoD")) { + + if (!ctx->HasInput("Y")) { auto level0 = ctx->Attrs().Get>("target_lod"); - PADDLE_ENFORCE(level0.size() > 1, - "Target LoD is not found, should be set to be a valid one " - "through Input() or Attr()."); + PADDLE_ENFORCE_GT(level0.size(), 1, + "If Input(Y) is not provided, the target lod should be " + "specified by attribute `target_lod`."); } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); } @@ -50,36 +49,42 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { public: LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "(LoDTensor) The input tensor of lod_reset operator."); - AddInput("TargetLoD", - "(Tensor, optional) The target level 0 LoD from Input().") + AddInput("X", + "(Tensor, LoDTensor) Input variable of LoDResetOp which " + "could be a Tensor or LoDTensor, where the data of output " + "variable inherits from."); + AddInput("Y", + "(Tensor, LoDTensor, optional) If provided, lod of Input(Y) would " + "be considered as the target lod first, otherwise data of " + "Input(Y) would be considered as the target lod.") .AsDispensable(); - AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator."); + AddOutput("Out", + "(LoDTensor) Output variable of LoDResetOp which should be a " + "LoDTensor."); AddAttr>("target_lod", "The target level 0 LoD from Attr().") .SetDefault(std::vector{}); AddComment(R"DOC(LoDReset operator -Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or -Attr(target_lod), or set LoD for Input(X) if it doesn't have one. -Currently the lod_reset operator only supports the reset of level 0 LoD. -At least one of Input(TargetLoD) and Attr(target_lod) must be set, -and if both of them are set, Input(TargetLoD) will be chosen as the -target LoD. +Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y` +provided, `Y.lod` would be considered as target LoD first, otherwise `Y.data` +would be considered as target LoD. If `Y` is not provided, target LoD should be +specified by attribute `target_lod`. If target LoD is specified by `Y.data` or +`target_lod`, only one level LoD is supported. An example: -Given a float LoDTensor X with shape (6, 1), its transpose form represents - - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], -with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like +Given a 1-level LoDTensor input(X) + X.lod = [[ 0, 2, 5 6 ]] + X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + X.dims = [6, 1] - [1.0, 2.0], [3.0, 4.0, 5.0], [6.0]. +target_lod: [0, 4, 6] -If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and -the sequences that the LoDTensor Output(Out) contains becomes: - - [1.0, 2.0, 3.0, 4.0], [5.0, 6.0]. 
+then we get an 1-level LoDTensor + Out.lod = [[ 0, 4, 6 ]] + Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + Out.dims = [6, 1] )DOC"); } @@ -90,10 +95,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of LoDResetGradOp should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + "Input(Out@Grad) of LoDResetGradOp should not be null."); + + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ x_grad_name); + } } protected: @@ -111,9 +122,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad, ops::LoDResetGradOp); -REGISTER_OP_CPU_KERNEL(lod_reset, - ops::LoDResetKernel, - ops::LoDResetKernel); +REGISTER_OP_CPU_KERNEL( + lod_reset, ops::LoDResetKernel, + ops::LoDResetKernel, + ops::LoDResetKernel, + ops::LoDResetKernel); REGISTER_OP_CPU_KERNEL( lod_reset_grad, ops::LoDResetGradKernel, - ops::LoDResetGradKernel); + ops::LoDResetGradKernel, + ops::LoDResetGradKernel, + ops::LoDResetGradKernel); diff --git a/paddle/fluid/operators/lod_reset_op.cu b/paddle/fluid/operators/lod_reset_op.cu index b0e87a851a..888d4c12eb 100644 --- a/paddle/fluid/operators/lod_reset_op.cu +++ b/paddle/fluid/operators/lod_reset_op.cu @@ -18,8 +18,12 @@ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( lod_reset, ops::LoDResetKernel, - ops::LoDResetKernel); + ops::LoDResetKernel, + ops::LoDResetKernel, + ops::LoDResetKernel); REGISTER_OP_CUDA_KERNEL( lod_reset_grad, ops::LoDResetGradKernel, - ops::LoDResetGradKernel); + ops::LoDResetGradKernel, + ops::LoDResetGradKernel, + ops::LoDResetGradKernel); diff --git a/paddle/fluid/operators/lod_reset_op.h b/paddle/fluid/operators/lod_reset_op.h index 8186d4f826..99f01c2a25 100644 --- a/paddle/fluid/operators/lod_reset_op.h +++ b/paddle/fluid/operators/lod_reset_op.h @@ -26,35 +26,46 @@ class LoDResetKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output("Out"); auto* in = ctx.Input("X"); - auto* lod_t = ctx.Input("TargetLoD"); + auto* lod_t = ctx.Input("Y"); + + out->ShareDataWith(*in); std::vector level0; if (lod_t) { - auto* lod = lod_t->data(); - if (platform::is_gpu_place(ctx.GetPlace())) { - framework::Tensor lod_cpu; - framework::TensorCopy(*lod_t, platform::CPUPlace(), - ctx.device_context(), &lod_cpu); - lod = lod_cpu.data(); + if (lod_t->lod().size() > 0) { + auto y_lod = lod_t->lod(); + auto last_level = y_lod[y_lod.size() - 1]; + PADDLE_ENFORCE_EQ(last_level.back(), in->dims()[0], + "Last value of `Y`'s last level LoD should be equal " + "to the first dimension of `X`"); + out->set_lod(y_lod); + return; // early return, since lod already set + } else { + auto* lod = lod_t->data(); + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::Tensor lod_cpu; + framework::TensorCopy(*lod_t, platform::CPUPlace(), + ctx.device_context(), &lod_cpu); + lod = lod_cpu.data(); + } + level0 = std::vector(lod, lod + lod_t->numel()); } - level0 = std::vector(lod, 
lod + lod_t->numel()); } else { level0 = ctx.Attr>("target_lod"); } - PADDLE_ENFORCE(level0.size() > 1UL, - "The size of target LoD should be greater than 1."); - PADDLE_ENFORCE(level0[0] == 0, - "Target LoD should be a vector starting from 0."); - PADDLE_ENFORCE(level0.back() == in->dims()[0], - "Target LoD should be a vector end with the " - "first dimension of Input(X)."); + PADDLE_ENFORCE_GT(level0.size(), 1UL, + "Size of target LoD should be greater than 1."); + PADDLE_ENFORCE_EQ(level0[0], 0, + "Target LoD should be a vector starting from 0."); + PADDLE_ENFORCE_EQ(level0.back(), in->dims()[0], + "Target LoD should be a vector end with the " + "first dimension of Input(X)."); for (size_t i = 0; i < level0.size() - 1; ++i) { PADDLE_ENFORCE(level0[i + 1] > level0[i], "Target LoD should be an ascending vector."); } - out->ShareDataWith(*in); // cast level0 to size_t std::vector ulevel0(level0.size(), 0); std::transform(level0.begin(), level0.end(), ulevel0.begin(), diff --git a/python/paddle/fluid/tests/unittests/test_lod_reset_op.py b/python/paddle/fluid/tests/unittests/test_lod_reset_op.py index 3bf8230f87..6b6d4c824a 100644 --- a/python/paddle/fluid/tests/unittests/test_lod_reset_op.py +++ b/python/paddle/fluid/tests/unittests/test_lod_reset_op.py @@ -42,7 +42,7 @@ class TestLodResetOpByInput(OpTest): target_lod_0 = [0, 4, 7, 10] self.inputs = { 'X': (x, lod), - 'TargetLoD': np.array([target_lod_0]).astype('int32') + 'Y': np.array([target_lod_0]).astype('int32') } self.outputs = {'Out': (x, [target_lod_0])} @@ -50,7 +50,7 @@ class TestLodResetOpByInput(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD")) + self.check_grad(["X"], "Out", no_grad_set=set("Y")) class TestLodResetOpBoth(OpTest): @@ -62,7 +62,7 @@ class TestLodResetOpBoth(OpTest): target_lod_0_in = [0, 4, 7, 10] self.inputs = { 'X': (x, lod), - 'TargetLoD': np.array(target_lod_0_in).astype('int32') + 'Y': np.array(target_lod_0_in).astype('int32') } self.attrs = {'target_lod': target_lod_0_attr} self.outputs = {'Out': (x, [target_lod_0_in])} @@ -71,7 +71,24 @@ class TestLodResetOpBoth(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD")) + self.check_grad(["X"], "Out", no_grad_set=set("Y")) + + +class TestLodResetOpYIsLoDTensor(OpTest): + def setUp(self): + self.op_type = "lod_reset" + x = np.random.random((10, 20)).astype("float32") + lod = [[0, 3, 5, 10]] + y = np.random.random((10, 10)).astype("float32") + target_lod_0 = [[0, 4, 7, 10]] + self.inputs = {'X': (x, lod), 'Y': (y, target_lod_0)} + self.outputs = {'Out': (x, target_lod_0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out", no_grad_set=set("Y")) if __name__ == '__main__': From 869a6f9cea8ebccda5701009fe61f3b9f684e43d Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 19 Mar 2018 21:49:33 +0800 Subject: [PATCH 09/26] Add python wrapper. 
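As a rough sketch of how the wrapper added in this patch is meant to be called (layer names and shapes here are made up; the two variants map to the operator's Y input and its target_lod attribute):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[10], dtype='float32')
    y = fluid.layers.data(name='y', shape=[10, 20], dtype='float32', lod_level=2)

    # Take the target LoD from another (LoD)Tensor ...
    out_from_y = fluid.layers.lod_reset(x=x, y=y)
    # ... or set a one-level LoD directly through the attribute.
    out_from_attr = fluid.layers.lod_reset(x=x, target_lod=[0, 4, 6, 10])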
--- paddle/fluid/operators/lod_reset_op.cc | 61 ++++++++--- python/paddle/fluid/layers/nn.py | 100 +++++++++++++++++- .../fluid/tests/unittests/test_layers.py | 9 ++ 3 files changed, 155 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 6599e183ef..7d5687f2d0 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -30,7 +30,7 @@ class LoDResetOp : public framework::OperatorWithKernel { if (!ctx->HasInput("Y")) { auto level0 = ctx->Attrs().Get>("target_lod"); PADDLE_ENFORCE_GT(level0.size(), 1, - "If Input(Y) is not provided, the target lod should be " + "If Input(Y) not provided, the target lod should be " "specified by attribute `target_lod`."); } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); @@ -54,9 +54,10 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { "could be a Tensor or LoDTensor, where the data of output " "variable inherits from."); AddInput("Y", - "(Tensor, LoDTensor, optional) If provided, lod of Input(Y) would " - "be considered as the target lod first, otherwise data of " - "Input(Y) would be considered as the target lod.") + "(Tensor, LoDTensor, optional) If provided and Y is LoDTensor, " + "lod of Input(Y) would be considered as the target lod first, " + "otherwise data of Input(Y) would be considered as the " + "target lod.") .AsDispensable(); AddOutput("Out", "(LoDTensor) Output variable of LoDResetOp which should be a " @@ -67,25 +68,59 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC(LoDReset operator Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y` -provided, `Y.lod` would be considered as target LoD first, otherwise `Y.data` -would be considered as target LoD. If `Y` is not provided, target LoD should be -specified by attribute `target_lod`. If target LoD is specified by `Y.data` or -`target_lod`, only one level LoD is supported. +provided and `Y` is a LoDTensor, `Y.lod` would be considered as target LoD +first, otherwise `Y.data` would be considered as target LoD. If `Y` is not +provided, target LoD should be specified by attribute `target_lod`. +If target LoD is specified by `Y.data` or `target_lod`, only one level LoD +is supported. 
-An example: +Example 1: -Given a 1-level LoDTensor input(X) - X.lod = [[ 0, 2, 5 6 ]] +Given a 1-level LoDTensor input(X): + X.lod = [[ 0, 2, 5 6 ]] X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] X.dims = [6, 1] -target_lod: [0, 4, 6] +attr(target_lod): [0, 4, 6] -then we get an 1-level LoDTensor +then we get a 1-level LoDTensor: Out.lod = [[ 0, 4, 6 ]] Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] Out.dims = [6, 1] +Example 2: + +Given a 1-level LoDTensor input(X): + X.lod = [[ 0, 2, 5 6 ]] + X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + X.dims = [6, 1] + +input(Y) is a Tensor: + Y.data = [[0, 2, 6]] + Y.dims = [1, 3] + +then we get a 1-level LoDTensor: + Out.lod = [[ 0, 2, 6 ]] + Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + Out.dims = [6, 1] + +Example 3: + +Given a 1-level LoDTensor input(X): + X.lod = [[ 0, 2, 5 6 ]] + X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + X.dims = [6, 1] + +input(Y) is a 2-level LoDTensor: + Y.lod = [[0, 2, 4], [0, 2, 5, 6]] + Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]] + Y.dims = [6, 1] + +then we get a 2-level LoDTensor: + Out.lod = [[0, 2, 4], [0, 2, 5, 6]] + Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + Out.dims = [6, 1] + )DOC"); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bf161d6618..8dced4bbfc 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -73,6 +73,7 @@ __all__ = [ 'smooth_l1', 'one_hot', 'autoincreased_step_counter', + 'lod_reset', ] @@ -2225,7 +2226,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): keep_dim (bool|False): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. - name(str|None): A name for this layer(optional). If set None, the + name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. Returns: @@ -2241,7 +2242,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_prod(x) # [0.0002268] fluid.layers.reduce_prod(x, dim=0) # [0.02, 0.06, 0.3, 0.63] fluid.layers.reduce_prod(x, dim=-1) # [0.027, 0.0084] - fluid.layers.reduce_prod(x, dim=1, + fluid.layers.reduce_prod(x, dim=1, keep_dim=True) # [[0.027], [0.0084]] """ helper = LayerHelper('reduce_prod', **locals()) @@ -3292,3 +3293,98 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): counter.stop_gradient = True return counter + + +def lod_reset(x, y, target_lod=None): + """ + LoD Reset Operator. Set LoD of **x** to a new one specified by **y** or + **target_lod**. When **y** provided, **y.lod** would be considered as target + LoD first, otherwise **y.data** would be considered as target LoD. If **y** + is not provided, target LoD should be specified by **target_lod**. + If target LoD is specified by **Y.data** or **target_lod**, only one level + LoD is supported. + + .. 
code-block:: text + + * Example 1: + + Given a 1-level LoDTensor x: + x.lod = [[ 0, 2, 5 6 ]] + x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + x.dims = [6, 1] + + target_lod: [0, 4, 6] + + then we get a 1-level LoDTensor: + out.lod = [[ 0, 4, 6 ]] + out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + out.dims = [6, 1] + + * Example 2: + + Given a 1-level LoDTensor x: + x.lod = [[ 0, 2, 5 6 ]] + x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + x.dims = [6, 1] + + y is a Tensor: + y.data = [[0, 2, 6]] + y.dims = [1, 3] + + then we get a 1-level LoDTensor: + out.lod = [[ 0, 2, 6 ]] + out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + out.dims = [6, 1] + + * Example 3: + + Given a 1-level LoDTensor x: + x.lod = [[ 0, 2, 5 6 ]] + x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + x.dims = [6, 1] + + y is a 2-level LoDTensor: + y.lod = [[0, 2, 4], [0, 2, 5, 6]] + y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]] + y.dims = [6, 1] + + then we get a 2-level LoDTensor: + out.lod = [[0, 2, 4], [0, 2, 5, 6]] + out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + out.dims = [6, 1] + + Args: + x (Variable): Input variable which could be a Tensor or LodTensor. + y (Variable|None): If provided, output's LoD would be derived from y. + target_lod (list|tuple|None): One level LoD which should be considered + as target LoD when y not provided. + + Returns: + Variable: Output variable with LoD specified by this operator. + + Raises: + ValueError: If y and target_lod are both None. + + Examples: + .. code-block:: python + + x = layers.data(name='x', shape=[10]) + y = layers.data(name='y', shape=[10, 20], lod_level=2) + out = layers.lod_reset(x=x, y=y) + """ + helper = LayerHelper("lod_reset", **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + if y is not None: + helper.append_op( + type="lod_reset", inputs={'X': x, + 'Y': y}, outputs={'Out': out}) + elif target_lod is not None: + helper.append_op( + type="lod_reset", + inputs={'X': x}, + attrs={'target_lod': target_lod}, + outputs={'Out': out}) + else: + raise ValueError("y and target_lod should not be both None.") + + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 90d70aa39f..744a762ae7 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -327,6 +327,15 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(loss) print(str(program)) + def test_lod_reset(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[10], dtype='float32') + y = layers.data( + name='y', shape=[10, 20], dtype='float32', lod_level=2) + print(layers.lod_reset(x=x, y=y)) + print(str(program)) + if __name__ == '__main__': unittest.main() From cd11b1bd5ce30c91166bfb131e8f2618703e13c8 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 19 Mar 2018 22:15:20 +0800 Subject: [PATCH 10/26] Set default value of y to None. --- python/paddle/fluid/layers/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8dced4bbfc..9656dcf94f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3295,7 +3295,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): return counter -def lod_reset(x, y, target_lod=None): +def lod_reset(x, y=None, target_lod=None): """ LoD Reset Operator. 
Set LoD of **x** to a new one specified by **y** or **target_lod**. When **y** provided, **y.lod** would be considered as target From 72847ad031cda087d28f806a830f7d5f5a785b63 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 19 Mar 2018 22:45:55 +0800 Subject: [PATCH 11/26] Add python API for Adadelta optimizer. --- python/paddle/fluid/optimizer.py | 57 +++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 1c12d53e4f..d104cc5cbd 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -24,7 +24,9 @@ from layer_helper import LayerHelper from regularizer import append_regularization_ops from clip import append_gradient_clip_ops, error_clip_callback -__all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad'] +__all__ = [ + 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Adadelta' +] class Optimizer(object): @@ -575,6 +577,58 @@ class DecayedAdagradOptimizer(Optimizer): return decayed_adagrad_op +class AdadeltaOptimizer(Optimizer): + """Simple Adadelta optimizer with average squared grad state and + average squared update state. + """ + _avg_squared_grad_acc_str = "_avg_squared_grad" + _avg_squared_update_acc_str = "_avg_squared_update" + + def __init__(self, learning_rate, epsilon=1.0e-6, rho=0.95, **kwargs): + assert learning_rate is not None + assert epsilon is not None + assert rho is not None + super(AdadeltaOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) + self.type = "adadelta" + self._epsilon = epsilon + self._rho = rho + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + for p in parameters: + self._add_accumulator(self._avg_squared_grad_acc_str, p) + self._add_accumulator(self._avg_squared_update_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + avg_squared_grad_acc = self._get_accumulator( + self._avg_squared_grad_acc_str, param_and_grad[0]) + avg_squared_update_acc = self._get_accumulator( + self._avg_squared_update_acc_str, param_and_grad[0]) + + # Create the adadelta optimizer op + adadelta_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "AvgSquaredGrad": avg_squared_grad_acc, + "AvgSquaredUpdate": avg_squared_update_acc + }, + outputs={ + "ParamOut": param_and_grad[0], + "AvgSquaredGradOut": avg_squared_grad_acc, + "AvgSquaredUpdateOut": avg_squared_update_acc + }, + attrs={"epsilon": self._epsilon, + "rho": self._rho}) + + return adadelta_op + + # We short the class name, since users will use the optimizer with the package # name. 
The sample code: # @@ -589,3 +643,4 @@ Adagrad = AdagradOptimizer Adam = AdamOptimizer Adamax = AdamaxOptimizer DecayedAdagrad = DecayedAdagradOptimizer +Adadelta = AdadeltaOptimizer From 35c373db113fbc782ca71225d2cb4d464e995194 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Mon, 19 Mar 2018 09:32:35 -0700 Subject: [PATCH 12/26] Support copy in Fluid channels (#9138) * Support copy in Fluid channels * Address PR review comments --- python/paddle/fluid/concurrency.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py index 535e881c42..0fc4981a8e 100644 --- a/python/paddle/fluid/concurrency.py +++ b/python/paddle/fluid/concurrency.py @@ -131,7 +131,7 @@ def make_channel(dtype, capacity=0): return channel -def channel_send(channel, value): +def channel_send(channel, value, copy=False): """ Sends a value through a channel variable. Used by an unbuffered or buffered channel to pass data from within or to a concurrent Go block, where @@ -141,6 +141,8 @@ def channel_send(channel, value): channel (Variable|Channel): Channel variable created using `make_channel`. value (Variable): Value to send to channel + copy (bool): Copy data while channel send. If False, then data + is moved. The input cannot be used after move. Returns: Variable: The boolean status on whether or not the channel successfully sent the passed value. @@ -162,11 +164,26 @@ def channel_send(channel, value): type=core.VarDesc.VarType.LOD_TENSOR, dtype=core.VarDesc.VarType.BOOL) + X = value + + if copy is True: + copied_X = helper.create_variable( + name=unique_name.generate(value.name + '_copy'), + type=value.type, + dtype=value.dtype, + shape=value.shape, + lod_level=value.lod_level, + capacity=value.capacity) + + assign_op = channel_send_block.append_op( + type="assign_op", inputs={"X": value}, outputs={"Out": copied_X}) + X = copied_X + channel_send_op = channel_send_block.append_op( type="channel_send", inputs={ "Channel": channel, - "X": value, + "X": X, }, outputs={"Status": status}) From 597ba3f3f25b44cb645e3a4b230239aed3749657 Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 20 Mar 2018 01:08:10 +0800 Subject: [PATCH 13/26] add more times close test (#9215) --- paddle/fluid/framework/channel_test.cc | 64 ++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/paddle/fluid/framework/channel_test.cc b/paddle/fluid/framework/channel_test.cc index edfb41c724..73be5cdbe2 100644 --- a/paddle/fluid/framework/channel_test.cc +++ b/paddle/fluid/framework/channel_test.cc @@ -871,3 +871,67 @@ TEST(ChannelHolder, ChannelHolderDestroyUnblocksSendersTest) { ch->Reset(0); ChannelHolderDestroyUnblockSenders(ch, false); } + +// This tests that closing a channelholder many times. +void ChannelHolderManyTimesClose(ChannelHolder *ch) { + const int num_threads = 15; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + + // Launches threads that try to send data to channel. + for (size_t i = 0; i < num_threads / 3; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *ended) { + int data = 10; + ch->Send(&data); + *ended = true; + }, + &thread_ended[i]); + } + + // Launches threads that try to receive data to channel. 
+ for (size_t i = num_threads / 3; i < 2 * num_threads / 3; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + int data; + if (ch->Receive(&data)) { + EXPECT_EQ(data, 10); + } + *p = true; + }, + &thread_ended[i]); + } + + // Launches threads that try to close the channel. + for (size_t i = 2 * num_threads / 3; i < num_threads; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + if (!ch->IsClosed()) { + ch->close(); + } + *p = true; + }, + &thread_ended[i]); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + + // Verify that all threads are unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + EXPECT_TRUE(ch->IsClosed()); + // delete the channel + delete ch; + for (size_t i = 0; i < num_threads; i++) t[i].join(); +} + +TEST(ChannelHolder, ChannelHolderManyTimesCloseTest) { + // Check for Buffered Channel + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(10); + ChannelHolderManyTimesClose(ch); +} From 9eae086e392b90d1bc6ea81c9ed69b88bceb86df Mon Sep 17 00:00:00 2001 From: Xi Chen Date: Mon, 19 Mar 2018 10:45:46 -0700 Subject: [PATCH 14/26] add math_function to softmax's dep list --- paddle/fluid/operators/math/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index fba1612d10..547d081006 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -43,7 +43,7 @@ math_library(sequence2batch) math_library(sequence_padding) math_library(sequence_pooling DEPS math_function) math_library(sequence_scale) -math_library(softmax) +math_library(softmax DEPS math_function) math_library(unpooling) math_library(vol2col) From 05ad15832aba64097759f8b7f232beba58cabedb Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 11:09:03 -0700 Subject: [PATCH 15/26] initial commit --- paddle/fluid/operators/dropout_op.cu | 15 ++++++----- .../fluid/tests/unittests/test_dropout_op.py | 26 +++++++++++++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index d6f9c04359..c949968a74 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -51,7 +52,7 @@ class GPUDropoutKernel : public framework::OpKernel { auto* x = context.Input("X"); auto* y = context.Output("Out"); y->mutable_data(context.GetPlace()); - AttrType dropout_prob = context.Attr("dropout_prob"); + AttrType dropout_prob = context.Attr("dropout_prob")); auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); @@ -74,7 +75,7 @@ class GPUDropoutKernel : public framework::OpKernel { context.cuda_device_context().stream()>>>( size, seed, dropout_prob, x_data, mask_data, y_data); } else { - Y.device(place) = X * (1.0f - dropout_prob); + Y.device(place) = X * static_cast(1.0f - dropout_prob); } } }; @@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - dropout, - ops::GPUDropoutKernel); -REGISTER_OP_CUDA_KERNEL( - dropout_grad, - ops::DropoutGradKernel); + dropout, ops::GPUDropoutKernel, + ops::GPUDropoutKernel); +REGISTER_OP_CUDA_KERNEL(dropout_grad, + ops::DropoutGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 60930a612c..6fcd5ac1a6 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -82,5 +82,31 @@ class TestDropoutOp5(OpTest): self.check_output() +class TestFP16DropoutOp1(OpTest): + def setUp(self): + x = np.random.random((32, 64)).astype("float16") + self.op_type = "dropout" + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'dropout_prob': 0.35, 'fix_seed': True, 'is_test': True} + self.outputs = {'Out': x * (1.0 - self.attrs['dropout_prob'])} + + def test_check_output(self): + if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): + self.check_output_with_place(core.CUDAPlace(0), atol=1e-3) + + +class TestFP16DropoutOp2(OpTest): + def setUp(self): + x = np.random.random((32, 64, 3)).astype("float16") + self.op_type = "dropout" + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'dropout_prob': 0.75, 'is_test': True} + self.outputs = {'Out': x * (1.0 - self.attrs['dropout_prob'])} + + def test_check_output(self): + if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): + self.check_output_with_place(core.CUDAPlace(0), atol=1e-3) + + if __name__ == '__main__': unittest.main() From d03dbb97f9f5de78f9edafff4608d829a415dc57 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 13:06:31 -0700 Subject: [PATCH 16/26] remove AttrType --- paddle/fluid/operators/dropout_op.cc | 9 +++------ paddle/fluid/operators/dropout_op.cu | 18 +++++++++--------- paddle/fluid/operators/dropout_op.h | 2 +- .../fluid/tests/unittests/test_dropout_op.py | 1 + 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 1074ed6acc..e4436549f6 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel { } }; -template class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { public: DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -73,7 +72,6 @@ are set equal to their corresponding inputs. 
} }; -template class DropoutOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, - ops::DropoutOpGrad); +REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, + ops::DropoutOpGrad); REGISTER_OP_CPU_KERNEL( - dropout, - ops::CPUDropoutKernel); + dropout, ops::CPUDropoutKernel); REGISTER_OP_CPU_KERNEL( dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index c949968a74..f6c85a2a53 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -23,13 +23,13 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template __global__ void RandomGenerator(const size_t n, const int seed, - const AttrType dropout_prob, const T* src, + const float dropout_prob, const T* src, T* mask_data, T* dst) { thrust::minstd_rand rng; rng.seed(seed); - thrust::uniform_real_distribution dist(0, 1); + thrust::uniform_real_distribution dist(0, 1); int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < n; idx += blockDim.x * gridDim.x) { @@ -45,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed, // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. -template +template class GPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* y = context.Output("Out"); y->mutable_data(context.GetPlace()); - AttrType dropout_prob = context.Attr("dropout_prob")); + float dropout_prob = context.Attr("dropout_prob"); auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); @@ -71,8 +71,8 @@ class GPUDropoutKernel : public framework::OpKernel { int threads = 512; int grid = (x->numel() + threads - 1) / threads; - RandomGenerator<<>>( + RandomGenerator< + T><<>>( size, seed, dropout_prob, x_data, mask_data, y_data); } else { Y.device(place) = X * static_cast(1.0f - dropout_prob); @@ -86,7 +86,7 @@ class GPUDropoutKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - dropout, ops::GPUDropoutKernel, - ops::GPUDropoutKernel); + dropout, ops::GPUDropoutKernel, + ops::GPUDropoutKernel); REGISTER_OP_CUDA_KERNEL(dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 209e4dec17..b5ee86ae2d 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -25,7 +25,7 @@ template using EigenMatrix = framework::EigenMatrix; -template +template class CPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 6fcd5ac1a6..5e2c460c41 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest 
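
The two dropout patches above add a float16 code path to the GPU kernel and FP16 test cases that only run with is_test=True. In that mode dropout applies no random mask and simply rescales its input by the keep ratio, which is why the tests can state the expected output in closed form. The reference computation they rely on can be sketched in a few lines of NumPy (the helper name is illustrative, not code from the patch):

    import numpy as np

    def dropout_inference_reference(x, dropout_prob):
        # Inference-mode dropout: no random mask, only a rescale by the
        # keep ratio (1 - dropout_prob).
        return x * (1.0 - dropout_prob)

    x = np.random.random((32, 64)).astype("float16")
    expected = dropout_inference_reference(x, dropout_prob=0.35)

Since float16 offers only about three significant decimal digits, the tests compare the GPU result against this reference with atol=1e-3 rather than exact equality.
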
From 18d616ed70ffe4477751770d8c55780751d76f44 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 14:48:15 -0700 Subject: [PATCH 17/26] add float16 arithmetic operators on new GPU --- paddle/fluid/platform/float16.h | 75 ++++++++++++++++++- .../fluid/tests/unittests/test_dropout_op.py | 14 +++- 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 52fb8c2531..a68dcc38ac 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -483,8 +483,77 @@ DEVICE inline bool operator>=(const half& a, const half& b) { #endif // PADDLE_CUDA_FP16 -// Arithmetic operators on ARMv8.2-A CPU -#if defined(PADDLE_WITH_NATIVE_FP16) +// Arithmetic operators for float16 on GPU +#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 +DEVICE inline float16 operator+(const float16& a, const float16& b) { + return float16(__hadd(half(a), half(b))); +} + +DEVICE inline float16 operator-(const float16& a, const float16& b) { + return float16(__hsub(half(a), half(b))); +} + +DEVICE inline float16 operator*(const float16& a, const float16& b) { + return float16(__hmul(half(a), half(b))); +} + +DEVICE inline float16 operator/(const float16& a, const float16& b) { + // TODO(kexinzhao): check the cuda version that starts to support __hdiv + float num = __half2float(half(a)); + float denom = __half2float(half(b)); + return float16(num / denom); +} + +DEVICE inline float16 operator-(const float16& a) { + return float16(__hneg(half(a))); +} + +DEVICE inline float16& operator+=(float16& a, const float16& b) { + a = a + b; + return a; +} + +DEVICE inline float16& operator-=(float16& a, const float16& b) { + a = a - b; + return a; +} + +DEVICE inline float16& operator*=(float16& a, const float16& b) { + a = a * b; + return a; +} + +DEVICE inline float16& operator/=(float16& a, const float16& b) { + a = a / b; + return a; +} + +DEVICE inline bool operator==(const float16& a, const float16& b) { + return __heq(half(a), half(b)); +} + +DEVICE inline bool operator!=(const float16& a, const float16& b) { + return __hne(half(a), half(b)); +} + +DEVICE inline bool operator<(const float16& a, const float16& b) { + return __hlt(half(a), half(b)); +} + +DEVICE inline bool operator<=(const float16& a, const float16& b) { + return __hle(half(a), half(b)); +} + +DEVICE inline bool operator>(const float16& a, const float16& b) { + return __hgt(half(a), half(b)); +} + +DEVICE inline bool operator>=(const float16& a, const float16& b) { + return __hge(half(a), half(b)); +} + +// Arithmetic operators for float16 on ARMv8.2-A CPU +#elif defined(PADDLE_WITH_NATIVE_FP16) HOST inline float16 operator+(const float16& a, const float16& b) { float16 res; asm volatile( @@ -668,7 +737,7 @@ HOST inline bool operator>=(const float16& a, const float16& b) { return (res & 0xffff) != 0; } -// Arithmetic operators, software emulated on other CPU +// Arithmetic operators for float16, software emulated on other CPU/GPU #else HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { return float16(float(a) + float(b)); diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 5e2c460c41..2939895d79 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -86,10 +86,13 @@ class TestDropoutOp5(OpTest): class TestFP16DropoutOp1(OpTest): def setUp(self): x = 
np.random.random((32, 64)).astype("float16") + prob = 0.35 + out = x * (1.0 - prob) + self.op_type = "dropout" self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'dropout_prob': 0.35, 'fix_seed': True, 'is_test': True} - self.outputs = {'Out': x * (1.0 - self.attrs['dropout_prob'])} + self.attrs = {'dropout_prob': prob, 'fix_seed': True, 'is_test': True} + self.outputs = {'Out': out} def test_check_output(self): if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): @@ -99,10 +102,13 @@ class TestFP16DropoutOp1(OpTest): class TestFP16DropoutOp2(OpTest): def setUp(self): x = np.random.random((32, 64, 3)).astype("float16") + prob = 0.75 + out = x * (1.0 - prob) + self.op_type = "dropout" self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'dropout_prob': 0.75, 'is_test': True} - self.outputs = {'Out': x * (1.0 - self.attrs['dropout_prob'])} + self.attrs = {'dropout_prob': prob, 'is_test': True} + self.outputs = {'Out': out} def test_check_output(self): if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): From f2bbbb2b660ac98def535b8ba41689d196c13127 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 15:52:22 -0700 Subject: [PATCH 18/26] fix arithmetic operator --- paddle/fluid/platform/float16.h | 101 +++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index a68dcc38ac..7c2c6add07 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -484,72 +484,107 @@ DEVICE inline bool operator>=(const half& a, const half& b) { #endif // PADDLE_CUDA_FP16 // Arithmetic operators for float16 on GPU -#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 -DEVICE inline float16 operator+(const float16& a, const float16& b) { +#if defined(PADDLE_CUDA_FP16) +HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hadd(half(a), half(b))); +#else + return float16(float(a) + float(b)); } -DEVICE inline float16 operator-(const float16& a, const float16& b) { +HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hsub(half(a), half(b))); +#else + return float16(float(a) - float(b)); } -DEVICE inline float16 operator*(const float16& a, const float16& b) { +HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hmul(half(a), half(b))); +#else + return float16(float(a) * float(b)); } -DEVICE inline float16 operator/(const float16& a, const float16& b) { - // TODO(kexinzhao): check the cuda version that starts to support __hdiv +HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 + // TODO(kexinzhao): check which cuda version starts to support __hdiv float num = __half2float(half(a)); float denom = __half2float(half(b)); return float16(num / denom); +#else + return float16(float(a) / float(b)); } -DEVICE inline float16 operator-(const float16& a) { +HOSTDEVICE inline float16 operator-(const float16& a) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hneg(half(a))); +#else + float16 res; + res.x = a.x ^ 0x8000; + return res; } -DEVICE inline float16& operator+=(float16& a, const float16& b) { +HOSTDEVICE inline float16& 
operator+=(float16& a, const float16& b) { a = a + b; return a; } -DEVICE inline float16& operator-=(float16& a, const float16& b) { +HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) { a = a - b; return a; } -DEVICE inline float16& operator*=(float16& a, const float16& b) { +HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) { a = a * b; return a; } -DEVICE inline float16& operator/=(float16& a, const float16& b) { +HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) { a = a / b; return a; } -DEVICE inline bool operator==(const float16& a, const float16& b) { +HOSTDEVICE inline bool operator==(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __heq(half(a), half(b)); +#else + return float(a) == float(b); } -DEVICE inline bool operator!=(const float16& a, const float16& b) { +HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hne(half(a), half(b)); +#else + return float(a) != float(b); } -DEVICE inline bool operator<(const float16& a, const float16& b) { +HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hlt(half(a), half(b)); +#else + return float(a) < float(b); } -DEVICE inline bool operator<=(const float16& a, const float16& b) { +HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hle(half(a), half(b)); +#else + return float(a) <= float(b); } -DEVICE inline bool operator>(const float16& a, const float16& b) { +HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hgt(half(a), half(b)); +#else + return float(a) > float(b); } -DEVICE inline bool operator>=(const float16& a, const float16& b) { +HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hge(half(a), half(b)); +#else + return float(a) >= float(b); } // Arithmetic operators for float16 on ARMv8.2-A CPU @@ -737,71 +772,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) { return (res & 0xffff) != 0; } -// Arithmetic operators for float16, software emulated on other CPU/GPU +// Arithmetic operators for float16, software emulated on other CPU #else -HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { +HOST inline float16 operator+(const float16& a, const float16& b) { return float16(float(a) + float(b)); } -HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { +HOST inline float16 operator-(const float16& a, const float16& b) { return float16(float(a) - float(b)); } -HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { +HOST inline float16 operator*(const float16& a, const float16& b) { return float16(float(a) * float(b)); } -HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { +HOST inline float16 operator/(const float16& a, const float16& b) { return float16(float(a) / float(b)); } -HOSTDEVICE inline float16 operator-(const float16& a) { +HOST inline float16 operator-(const float16& a) { float16 res; res.x = a.x ^ 0x8000; return res; } -HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) { +HOST inline float16& operator+=(float16& a, const float16& b) { a = float16(float(a) + float(b)); return a; } -HOSTDEVICE inline 
float16& operator-=(float16& a, const float16& b) { +HOST inline float16& operator-=(float16& a, const float16& b) { a = float16(float(a) - float(b)); return a; } -HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) { +HOST inline float16& operator*=(float16& a, const float16& b) { a = float16(float(a) * float(b)); return a; } -HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) { +HOST inline float16& operator/=(float16& a, const float16& b) { a = float16(float(a) / float(b)); return a; } -HOSTDEVICE inline bool operator==(const float16& a, const float16& b) { +HOST inline bool operator==(const float16& a, const float16& b) { return float(a) == float(b); } -HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { +HOST inline bool operator!=(const float16& a, const float16& b) { return float(a) != float(b); } -HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { +HOST inline bool operator<(const float16& a, const float16& b) { return float(a) < float(b); } -HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { +HOST inline bool operator<=(const float16& a, const float16& b) { return float(a) <= float(b); } -HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { +HOST inline bool operator>(const float16& a, const float16& b) { return float(a) > float(b); } -HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { +HOST inline bool operator>=(const float16& a, const float16& b) { return float(a) >= float(b); } #endif From 182da95317f7a1f011d46adfb096ac2f6b44e99f Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 16:01:30 -0700 Subject: [PATCH 19/26] small fix --- paddle/fluid/platform/float16.h | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 7c2c6add07..d3312a47f4 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -490,6 +490,7 @@ HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { return float16(__hadd(half(a), half(b))); #else return float16(float(a) + float(b)); +#endif } HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { @@ -497,6 +498,7 @@ HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { return float16(__hsub(half(a), half(b))); #else return float16(float(a) - float(b)); +#endif } HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { @@ -504,6 +506,7 @@ HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { return float16(__hmul(half(a), half(b))); #else return float16(float(a) * float(b)); +#endif } HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { @@ -514,6 +517,7 @@ HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { return float16(num / denom); #else return float16(float(a) / float(b)); +#endif } HOSTDEVICE inline float16 operator-(const float16& a) { @@ -523,6 +527,7 @@ HOSTDEVICE inline float16 operator-(const float16& a) { float16 res; res.x = a.x ^ 0x8000; return res; +#endif } HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) { @@ -550,6 +555,7 @@ HOSTDEVICE inline bool operator==(const float16& a, const float16& b) { return __heq(half(a), half(b)); #else return float(a) == float(b); +#endif } HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { @@ -557,6 +563,7 @@ HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { return 
__hne(half(a), half(b)); #else return float(a) != float(b); +#endif } HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { @@ -564,6 +571,7 @@ HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { return __hlt(half(a), half(b)); #else return float(a) < float(b); +#endif } HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { @@ -571,6 +579,7 @@ HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { return __hle(half(a), half(b)); #else return float(a) <= float(b); +#endif } HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { @@ -578,6 +587,7 @@ HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { return __hgt(half(a), half(b)); #else return float(a) > float(b); +#endif } HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { @@ -585,6 +595,7 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { return __hge(half(a), half(b)); #else return float(a) >= float(b); +#endif } // Arithmetic operators for float16 on ARMv8.2-A CPU From 839f4fa2b0d361270ca20f607d127f0ab0aaf8b7 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 20 Mar 2018 09:28:51 +0800 Subject: [PATCH 20/26] move distributed lookup table design to fluid/dist_train --- .../dist_train}/distributed_lookup_table_design.md | 4 ++-- .../design/dist_train/src}/lookup_table.png | Bin .../dist_train/src}/lookup_table_training.png | Bin 3 files changed, 2 insertions(+), 2 deletions(-) rename doc/{design => fluid/design/dist_train}/distributed_lookup_table_design.md (97%) rename doc/{design => fluid/design/dist_train/src}/lookup_table.png (100%) rename doc/{design => fluid/design/dist_train/src}/lookup_table_training.png (100%) diff --git a/doc/design/distributed_lookup_table_design.md b/doc/fluid/design/dist_train/distributed_lookup_table_design.md similarity index 97% rename from doc/design/distributed_lookup_table_design.md rename to doc/fluid/design/dist_train/distributed_lookup_table_design.md index a09f2818c8..e543adf0f9 100644 --- a/doc/design/distributed_lookup_table_design.md +++ b/doc/fluid/design/dist_train/distributed_lookup_table_design.md @@ -26,7 +26,7 @@ lookup of rows. 
The following figure illustrates the multiplication of x with two non-zero elements, or say, two symbols, and a lookup table W: -![lookup table](./lookup_table.png) +![lookup table](./src/lookup_table.png) ### The Backward Algorithm @@ -42,7 +42,7 @@ or some more sophisticated algorithms that rely on both W' and W: $$W = f(W, W')$$ The following figure illustrates the backward pass of the lookup -operator: ![lookup table training](./lookup_table_training.png) +operator: ![lookup table training](./src/lookup_table_training.png) ## Distributed Storage Service diff --git a/doc/design/lookup_table.png b/doc/fluid/design/dist_train/src/lookup_table.png similarity index 100% rename from doc/design/lookup_table.png rename to doc/fluid/design/dist_train/src/lookup_table.png diff --git a/doc/design/lookup_table_training.png b/doc/fluid/design/dist_train/src/lookup_table_training.png similarity index 100% rename from doc/design/lookup_table_training.png rename to doc/fluid/design/dist_train/src/lookup_table_training.png From b678b826e619a5659dece5248b3867ccf510302f Mon Sep 17 00:00:00 2001 From: Shan Yi <35982308+shanyi15@users.noreply.github.com> Date: Tue, 20 Mar 2018 10:02:20 +0800 Subject: [PATCH 21/26] repair deadlink of fluid doc repair link of "Executor" --- doc/fluid/design/motivation/fluid.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/design/motivation/fluid.md b/doc/fluid/design/motivation/fluid.md index f78fa8c191..110b7d78bf 100644 --- a/doc/fluid/design/motivation/fluid.md +++ b/doc/fluid/design/motivation/fluid.md @@ -103,7 +103,7 @@ In computability theory, a system of data-manipulation rules, such as a programm There are two ways to execute a Fluid program. When a program is executed, it creates a protobuf message [`ProgramDesc`](https://github.com/PaddlePaddle/Paddle/blob/a91efdde6910ce92a78e3aa7157412c4c88d9ee8/paddle/framework/framework.proto#L145) that describes the process and is conceptually like an [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree). -There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program. +There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program. Fluid is moving towards the direction of a compiler, which is explain in [fluid_compiler.md](fluid_compiler.md). 
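
The design document moved by the patch above describes the lookup table forward pass as a multiplication by a sparse one-hot input, which reduces to gathering rows of W, and the backward pass as a sparse update W = f(W, W') that touches only the looked-up rows. A minimal NumPy sketch of both directions, with illustrative names only (table, lookup_forward and sgd_update are not identifiers from the repository):

    import numpy as np

    table = np.random.rand(10000, 64).astype("float32")  # the lookup table W

    def lookup_forward(ids):
        # Multiplying a one-hot sparse row by W is equivalent to selecting
        # the corresponding row of W.
        return table[ids]

    def sgd_update(ids, grad_rows, lr=0.01):
        # Only the rows that were looked up receive gradient, so the update
        # W = f(W, W') is sparse and cheap even for a very large table.
        np.add.at(table, ids, -lr * grad_rows)

    ids = np.array([3, 98, 3])
    emb = lookup_forward(ids)             # shape (3, 64)
    sgd_update(ids, np.ones_like(emb))

The sparsity of this update is what the rest of the document builds on when it moves W into a distributed storage service.
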
From 6f7e812bb3726368165505caca8b752841812497 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 10:50:01 +0800 Subject: [PATCH 22/26] fix bugs --- .../reader/create_double_buffer_reader_op.cc | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 8960fe5d63..bd0bb2ee3b 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -48,20 +48,24 @@ class DoubleBufferReader : public framework::DecoratedReader { void start_thread() { buffer_ = framework::MakeChannel(kDoubleBufferSize); - std::thread prefetch([this] { PrefetchThreadFunc(); }); - prefetch.detach(); + prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); } void ReadNext(std::vector* out) override; void ReInit() override; - ~DoubleBufferReader() { buffer_->Close(); } + ~DoubleBufferReader() { + buffer_->Close(); + prefetcher_.join(); + delete buffer_; + } bool HasNext() const override; private: void PrefetchThreadFunc(); + std::thread prefetcher_; framework::Channel* buffer_; platform::Place place_; std::vector> ctxs_; @@ -134,6 +138,8 @@ void DoubleBufferReader::ReadNext(std::vector* out) { void DoubleBufferReader::ReInit() { reader_->ReInit(); buffer_->Close(); + prefetcher_.join(); + delete buffer_; start_thread(); } @@ -159,8 +165,8 @@ void DoubleBufferReader::PrefetchThreadFunc() { if (!buffer_->Send(&batch)) { VLOG(5) << "WARNING: The double buffer channel has been closed. The " - "prefetch thread terminate."; - return; + "prefetch thread will terminate."; + break; } } buffer_->Close(); From 2c225525424909942019eb154d5c9f1f2229d8a6 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 20 Mar 2018 11:08:29 +0800 Subject: [PATCH 23/26] Fix some comments and adapt test_machine_translation.py. 
--- paddle/fluid/operators/sequence_expand_op.cc | 4 ++-- paddle/fluid/operators/sequence_expand_op.h | 7 ------- python/paddle/fluid/tests/book/test_machine_translation.py | 6 +++--- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/sequence_expand_op.cc b/paddle/fluid/operators/sequence_expand_op.cc index d4bf6034ed..786fe63e75 100644 --- a/paddle/fluid/operators/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_expand_op.cc @@ -145,7 +145,7 @@ and input(Y) [0, 3, 6, 6, 8]] ref_level: 0 then we get 1-level LoDTensor - Out.lod = [[0, 2, 5, 8]] + Out.lod = [[0, 1, 2, 5, 8]] Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]] Out.dims = [8, 1] @@ -157,7 +157,7 @@ Given a common Tensor input(X) and input(Y) Y.lod = [[0, 2, 3, 6]] ref_level: -1 -then we a common Tensor +then we get a common Tensor Out.data = [[a], [a], [b], [c], [c], [c]] Out.dims = [6, 1] diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index eea3cf0440..db7d8bd682 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -37,13 +37,6 @@ class SequenceExpandKernel : public framework::OpKernel { int ref_level = context.Attr("ref_level"); auto& x_lod = x->lod(); auto& y_lod = y->lod(); - PADDLE_ENFORCE_GT(y_lod.size(), 0, - "Level number of `Y`'s lod should be greater than 0."); - PADDLE_ENFORCE( - ref_level == -1 || (ref_level >= 0 && ref_level < y_lod.size()), - "Invlid `ref_level`, which should be either equal to -1 " - "or in [0, %d)", - y_lod.size()); if (ref_level == -1) ref_level = y_lod.size() - 1; diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py index fa38bd3762..3a1a0859ec 100644 --- a/python/paddle/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/test_machine_translation.py @@ -118,12 +118,12 @@ def decoder_decode(context, is_sparse): is_sparse=is_sparse) # use rnn unit to update rnn - current_state = pd.fc(input=[pre_ids_emb, pre_state_expanded], + current_state = pd.fc(input=[pre_state_expanded, pre_ids_emb], size=decoder_size, act='tanh') - + current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score) # use score to do beam search - current_score = pd.fc(input=current_state, + current_score = pd.fc(input=current_state_with_lod, size=target_dict_dim, act='softmax') topk_scores, topk_indices = pd.topk(current_score, k=50) From 509c8399b8dbf1491fd6adc55b3c423e2d3501be Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 20:16:16 -0700 Subject: [PATCH 24/26] address comments --- .../fluid/tests/unittests/test_dropout_op.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 2939895d79..eaa3435a86 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -83,36 +83,36 @@ class TestDropoutOp5(OpTest): self.check_output() -class TestFP16DropoutOp1(OpTest): +class TestFP16DropoutOp(OpTest): def setUp(self): - x = np.random.random((32, 64)).astype("float16") - prob = 0.35 - out = x * (1.0 - prob) - self.op_type = "dropout" + self.init_test_case() + + x = np.random.random(self.input_size).astype("float16") + out = x * (1.0 - self.prob) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = 
{'dropout_prob': prob, 'fix_seed': True, 'is_test': True} + self.attrs = { + 'dropout_prob': self.prob, + 'fix_seed': self.fix_seed, + 'is_test': True + } self.outputs = {'Out': out} + def init_test_case(self): + self.input_size = [32, 64] + self.prob = 0.35 + self.fix_seed = True + def test_check_output(self): if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): self.check_output_with_place(core.CUDAPlace(0), atol=1e-3) -class TestFP16DropoutOp2(OpTest): - def setUp(self): - x = np.random.random((32, 64, 3)).astype("float16") - prob = 0.75 - out = x * (1.0 - prob) - - self.op_type = "dropout" - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'dropout_prob': prob, 'is_test': True} - self.outputs = {'Out': out} - - def test_check_output(self): - if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): - self.check_output_with_place(core.CUDAPlace(0), atol=1e-3) +class TestFP16DropoutOp2(TestFP16DropoutOp): + def init_test_case(self): + self.input_size = [32, 64, 3] + self.prob = 0.75 + self.fix_seed = False if __name__ == '__main__': From 7c59ac484f867e753143b9bf29838d35056f03fa Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 20 Mar 2018 12:37:03 +0800 Subject: [PATCH 25/26] Refine doc and use 'raise' instead of assert --- doc/v2/api/fluid/optimizer.rst | 7 ++++++ python/paddle/fluid/optimizer.py | 42 +++++++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/doc/v2/api/fluid/optimizer.rst b/doc/v2/api/fluid/optimizer.rst index 9b165f8704..2f820595c3 100644 --- a/doc/v2/api/fluid/optimizer.rst +++ b/doc/v2/api/fluid/optimizer.rst @@ -47,3 +47,10 @@ DecayedAdagrad :members: :noindex: +Adadelta +-------------- + +.. autoclass:: paddle.fluid.optimizer.AdadeltaOptimizer + :members: + :noindex: + diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 75e1de00cf..e8623ee0da 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -583,16 +583,44 @@ class DecayedAdagradOptimizer(Optimizer): class AdadeltaOptimizer(Optimizer): - """Simple Adadelta optimizer with average squared grad state and + """ + **Adadelta Optimizer** + Simple Adadelta optimizer with average squared grad state and average squared update state. + The details of adadelta please refer to this + `ADADELTA: AN ADAPTIVE LEARNING RATE METHOD + `_. + + .. math:: + + E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2 \\\\ + learning\\_rate &= sqrt( ( E(dx_{t-1}^2) + \\epsilon ) / ( \\ + E(g_t^2) + \\epsilon ) ) \\\\ + E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\\_rate)^2 + + Args: + learning_rate(float): global leraning rate + rho(float): rho in equation + epsilon(float): epsilon in equation + + Examples: + .. 
code-block:: python + + optimizer = fluid.optimizer.Adadelta( + learning_rate=0.0003, epsilon=1.0e-6, rho=0.95) + _, params_grads = optimizer.minimize(cost) """ + _avg_squared_grad_acc_str = "_avg_squared_grad" _avg_squared_update_acc_str = "_avg_squared_update" def __init__(self, learning_rate, epsilon=1.0e-6, rho=0.95, **kwargs): - assert learning_rate is not None - assert epsilon is not None - assert rho is not None + if learning_rate is None: + raise ValueError("learning_rate is not set.") + if epsilon is None: + raise ValueError("epsilon is not set.") + if rho is None: + raise ValueError("rho is not set.") super(AdadeltaOptimizer, self).__init__( learning_rate=learning_rate, **kwargs) self.type = "adadelta" @@ -600,14 +628,16 @@ class AdadeltaOptimizer(Optimizer): self._rho = rho def _create_accumulators(self, block, parameters): - assert isinstance(block, framework.Block) + if not isinstance(block, framework.Block): + raise TypeError("block is not instance of framework.Block.") for p in parameters: self._add_accumulator(self._avg_squared_grad_acc_str, p) self._add_accumulator(self._avg_squared_update_acc_str, p) def _append_optimize_op(self, block, param_and_grad): - assert isinstance(block, framework.Block) + if not isinstance(block, framework.Block): + raise TypeError("block is not instance of framework.Block.") avg_squared_grad_acc = self._get_accumulator( self._avg_squared_grad_acc_str, param_and_grad[0]) From c346a345e05e2e17203d693c61be13c541016834 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 16:35:52 +0800 Subject: [PATCH 26/26] fix a bug --- paddle/fluid/recordio/header.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/recordio/header.cc b/paddle/fluid/recordio/header.cc index e50de15b7c..ed09d58f6a 100644 --- a/paddle/fluid/recordio/header.cc +++ b/paddle/fluid/recordio/header.cc @@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs) bool Header::Parse(std::istream& is) { uint32_t magic; - size_t read_size = - is.readsome(reinterpret_cast(&magic), sizeof(uint32_t)); + is.read(reinterpret_cast(&magic), sizeof(uint32_t)); + size_t read_size = is.gcount(); if (read_size < sizeof(uint32_t)) { return false; }
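
The AdadeltaOptimizer docstring added earlier in this series states the update rule in terms of the running averages E(g^2) and E(dx^2). One step of that rule can be reproduced with a short NumPy sketch (names are illustrative, and the global learning_rate argument is left out since only the adaptive per-element rate appears in the documented equations):

    import numpy as np

    def adadelta_step(param, grad, avg_sq_grad, avg_sq_update,
                      rho=0.95, eps=1.0e-6):
        # E(g_t^2) = rho * E(g_{t-1}^2) + (1 - rho) * g^2
        avg_sq_grad = rho * avg_sq_grad + (1 - rho) * grad ** 2
        # adaptive rate = sqrt((E(dx_{t-1}^2) + eps) / (E(g_t^2) + eps))
        rate = np.sqrt((avg_sq_update + eps) / (avg_sq_grad + eps))
        update = -rate * grad
        # E(dx_t^2) = rho * E(dx_{t-1}^2) + (1 - rho) * update^2
        avg_sq_update = rho * avg_sq_update + (1 - rho) * update ** 2
        return param + update, avg_sq_grad, avg_sq_update

    param = np.zeros(4)
    grad = np.array([0.1, -0.2, 0.3, 0.0])
    param, s_g, s_u = adadelta_step(param, grad, np.zeros(4), np.zeros(4))

On the first step both accumulators are zero, so the rate reduces to sqrt(eps / ((1 - rho) * g^2 + eps)), giving small, well-conditioned initial updates.
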