From f2c4bb679bb3f247c12888b878ec4dddff358e49 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 21 Jan 2018 09:24:40 -0800 Subject: [PATCH 01/17] Add lstm with recurrent projection operator --- paddle/operators/lstmp_op.cc | 296 ++++++++++++++ paddle/operators/lstmp_op.cu.cc | 24 ++ paddle/operators/lstmp_op.h | 384 ++++++++++++++++++ python/paddle/v2/fluid/tests/test_lstmp_op.py | 314 ++++++++++++++ 4 files changed, 1018 insertions(+) create mode 100644 paddle/operators/lstmp_op.cc create mode 100644 paddle/operators/lstmp_op.cu.cc create mode 100644 paddle/operators/lstmp_op.h create mode 100644 python/paddle/v2/fluid/tests/test_lstmp_op.py diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc new file mode 100644 index 0000000000..4c7f7713ee --- /dev/null +++ b/paddle/operators/lstmp_op.cc @@ -0,0 +1,296 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/lstmp_op.h" + +namespace paddle { +namespace operators { + +class LSTMPOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ProjWeight"), + "Input(ProjWeight) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), + "Input(Bias) of LSTMP should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("Projection"), + "Output(Projection) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Output(Cell) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), + "Output(BatchGate) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), + "Output(BatchGate) of LSTMP should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2."); + + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(C0) and Input(H0) of LSTMP should not " + "be null at the same time."); + auto h_dims = ctx->GetInputDim("H0"); + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + int frame_size = in_dims[1] / 4; + auto w_dims = ctx->GetInputDim("Weight"); + auto proj_dims = ctx->GetInputDim("ProjWeight"); + PADDLE_ENFORCE_EQ(w_dims.size(), 2, + "The rank of Input(Weight) should be 2."); + PADDLE_ENFORCE_EQ(w_dims[0], proj_dims[1], + "The first dimension of Input(Weight) " + "should be %d.", + proj_dims[1]); + PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size, + "The second dimension of Input(Weight) " + "should be 4 * %d.", + frame_size); + + PADDLE_ENFORCE_EQ(proj_dims.size(), 2, + "The rank of Input(ProjWeight) should be 2."); + PADDLE_ENFORCE_EQ(proj_dims[0], frame_size, + "The first dimension of Input(ProjWeight) " + "should be %d.", + frame_size); + + auto b_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, + "The first dimension of Input(Bias) should be 1."); + + if (ctx->Attrs().Get("use_peepholes")) { + PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection", + frame_size); + } else { + PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, + "The second dimension of Input(Bias) should be " + "4 * %d if disable peepholes connection", + frame_size); + } + + framework::DDim out_dims({in_dims[0], frame_size}); + framework::DDim proj_out_dims({in_dims[0], proj_dims[1]}); + ctx->SetOutputDim("Projection", proj_out_dims); + ctx->SetOutputDim("Cell", out_dims); + ctx->SetOutputDim("BatchGate", in_dims); + ctx->SetOutputDim("BatchCellPreAct", out_dims); + ctx->ShareLoD("Input", "Projection"); + ctx->ShareLoD("Input", "Cell"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + ctx.device_context()); + } +}; + +class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LSTMPOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(LoDTensor) the first input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTensor is a matrix with shape (T X 4D), where T is the " + "total time steps in this mini-batch, D is the hidden size."); + AddInput("H0", + "(Tensor, optional) the initial hidden state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size and D is the hidden size.") + .AsDispensable(); + AddInput("C0", + "(Tensor, optional) the initial cell state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size. `H0` and `C0` can be NULL but only at the same time") + .AsDispensable(); + AddInput("Weight", + "(Tensor) the learnable hidden-hidden weights." + " - The shape is (P x 4D), where P is the recurrent projection " + "layer size and D is the hidden size. " + " - Weight = {W_cr, W_ir, W_fr, W_or}"); + AddInput("ProjWeight", + "(Tensor) the learnable weight `W_rh` of the projection layer." + " - The shape is (D x P), where P is the recurrent projection " + "layer size and D is the hidden size."); + AddInput("Bias", + "(Tensor) the learnable weights, which contains two parts: " + "input-hidden bias weight and peephole connections weight if " + "setting `use_peepholes` True. " + "1. `use_peepholes = False` " + " - The shape is (1 x 4D). " + " - Bias = {b_c, b_i, b_f, b_o}." + "2. `use_peepholes = True` " + " - The shape is (1 x 7D). " + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); + AddOutput("Projection", + "(LoDTensor) the projection of the hidden state of LSTMP " + "operator. The shape is (T x P), and lod is the same with the " + "`Input`."); + AddOutput("Cell", + "(LoDTensor) the cell state of LSTMP operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("BatchGate", + "(LoDTensor) This LoDTensor contains input gate, forget gate " + "and output gate after the nonlinear computation. This " + "LoDTensor has the same shape as the reorganized input, which " + "is also be called batch input. The LoD size is 2. The first " + "LoD is the batch offsets and the second LoD contains the " + "indexes, which denote the position of reorganized sequence " + "in the raw input.") + .AsIntermediate(); + AddOutput("BatchCellPreAct", + "(LoDTensor) This LoDTensor is obtained in the forward and used " + "in the backward.") + .AsIntermediate(); + AddAttr("use_peepholes", + "(bool, defalut: True) " + "whether to enable diagonal/peephole connections.") + .SetDefault(true); + AddAttr("is_reverse", + "(bool, defalut: False) " + "whether to compute reversed LSTMP.") + .SetDefault(false); + AddAttr( + "gate_activation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("cell_activation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("candidate_activation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddComment(R"DOC( +Long-Short Term Memory with Recurrent Projection (LSTMP) Operator. + +LATMP is stand LSTM appended by a recurrent projection layer to reduce the +number of parameters, espeacially when the output size is relative large. +The formula is as follows: + +$$ +i_t = \sigma(W_{ix}x_{t} + W_{ih}r_{t-1} + W_{ic}c_{t-1} + b_i) \\ + +f_t = \sigma(W_{fx}x_{t} + W_{fh}r_{t-1} + W_{fc}c_{t-1} + b_f) \\ + +c_t = f_t \odot c_{t-1} + i_t \odot act_g(W_{cx}x_t + W_{ch}r_{t-1} + b_c) \\ + +o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\ + +h_t = o_t \odot act_h(c_t) + +r_t = W_{rh}h_t +$$ + +where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix +of weights from the input gate to the input), $W_{ic}, W_{fc}, W_{oc}$ +are diagonal weight matrices for peephole connections. In our implementation, +we use vectors to reprenset these diagonal weight matrices. The b terms +denote bias vectors ($b_i$ is the input gate bias vector), $\sigma$ +is the non-line activations, such as logistic sigmoid function, and +$i, f, o$ and $c$ are the input gate, forget gate, output gate, +and cell activation vectors, respectively, all of which have the same size as +the cell output activation vector $h$. $r$ denotes the recurrent projection +layer. + +The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$ +are the cell input and cell output activation functions and `tanh` is usually +used for them. + +Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$ +operations on the input $x_{t}$ are NOT included in this operator. +Users can choose to use fully-connect operator before LSTMP operator. + +)DOC"); + } +}; + +class LSTMPGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Hidden"), + "Input(Hidden) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Cell"), + "Input(Cell) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), + "Input(Bias) of LSTMP should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("BatchGate"), + "Input(BatchGate) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"), + "Input(BatchGate) of LSTMP should not be null."); + + auto SetOutGradDim = [&ctx](const std::string& name) { + auto g_name = framework::GradVarName(name); + if (ctx->HasOutput(g_name)) + ctx->SetOutputDim(g_name, ctx->GetInputDim(name)); + }; + + SetOutGradDim("Input"); + SetOutGradDim("Weight"); + SetOutGradDim("Bias"); + SetOutGradDim("H0"); + SetOutGradDim("C0"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstmp, ops::LSTMPOp, ops::LSTMPOpMaker, lstmp_grad, + ops::LSTMPGradOp); +REGISTER_OP_CPU_KERNEL( + lstmp, ops::LSTMPKernel, + ops::LSTMPKernel); +REGISTER_OP_CPU_KERNEL( + lstmp_grad, ops::LSTMPGradKernel, + ops::LSTMPGradKernel); diff --git a/paddle/operators/lstmp_op.cu.cc b/paddle/operators/lstmp_op.cu.cc new file mode 100644 index 0000000000..7fcbcfecc8 --- /dev/null +++ b/paddle/operators/lstmp_op.cu.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/lstmp_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + lstmp, ops::LSTMPKernel, + ops::LSTMPKernel); +REGISTER_OP_CUDA_KERNEL( + lstmp_grad, + ops::LSTMPGradKernel, + ops::LSTMPGradKernel); diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h new file mode 100644 index 0000000000..f5a38b2ff5 --- /dev/null +++ b/paddle/operators/lstmp_op.h @@ -0,0 +1,384 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/detail/activation_functions.h" +#include "paddle/operators/math/lstm_compute.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +inline void ReorderInitState(const DeviceContext& ctx, + const framework::Tensor& src, const size_t* index, + framework::Tensor* dst, bool indexed_src) { + math::CopyMatrixRowsFunctor row_shuffle; + dst->mutable_data(src.dims(), ctx.GetPlace()); + row_shuffle(ctx, src, index, *dst, indexed_src); +} + +template +class LSTMPKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* proj_weight = ctx.Input("ProjWeight"); + auto* bias = ctx.Input("Bias"); + + auto* hidden_t0 = ctx.Input("H0"); + auto* cell_t0 = ctx.Input("C0"); + + auto* batch_gate = ctx.Output("BatchGate"); + batch_gate->mutable_data(ctx.GetPlace()); + auto* proj_out = ctx.Output("Projection"); + proj_out->mutable_data(ctx.GetPlace()); + auto* cell_out = ctx.Output("Cell"); + cell_out->mutable_data(ctx.GetPlace()); + + bool is_reverse = ctx.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& device_ctx = ctx.template device_context(); + to_batch(device_ctx, *input, *batch_gate, true, is_reverse); + + auto in_dims = input->dims(); + int frame_size = static_cast(in_dims[1] / 4); + framework::DDim dims({in_dims[0], frame_size}); + framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]}); + + if (bias) { + Tensor b = *bias; + b.Resize({bias->numel(), 1}); + Tensor gate_bias = b.Slice(0, 4 * frame_size); + math::RowwiseAdd add_bias; + add_bias(device_ctx, *batch_gate, gate_bias, batch_gate); + } + + math::LstmMetaValue lstmp_value; + if (bias && ctx.Attr("use_peepholes")) { + T* bias_data = const_cast(bias->data()); + // the code style in LstmpMetaValue will be updated later. + + lstmp_value.check_ig = bias_data + 4 * frame_size; + lstmp_value.check_fg = lstmp_value.check_ig + frame_size; + lstmp_value.check_og = lstmp_value.check_fg + frame_size; + } else { + lstmp_value.check_ig = nullptr; + lstmp_value.check_fg = nullptr; + lstmp_value.check_og = nullptr; + } + lstmp_value.prev_state_value = nullptr; + Tensor ordered_c0; + const size_t* order = batch_gate->lod()[2].data(); + if (cell_t0) { + // Since the batch computing for LSTMP reorders the input sequence + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState(device_ctx, *cell_t0, order, + &ordered_c0, true); + lstmp_value.prev_state_value = ordered_c0.data(); + } + + // Use the local variable as here. + LoDTensor batch_hidden, batch_proj, batch_cell; + auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + batch_hidden.mutable_data(dims, ctx.GetPlace()); // T x D + batch_proj.mutable_data(proj_dims, ctx.GetPlace()); // T x P + batch_cell.mutable_data(dims, ctx.GetPlace()); // T x D + batch_cell_pre_act->mutable_data(dims, ctx.GetPlace()); + + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto gate_act = math::detail::GetActivationType( + ctx.Attr("gate_activation")); + auto cell_act = math::detail::GetActivationType( + ctx.Attr("cell_activation")); + auto cand_act = math::detail::GetActivationType( + ctx.Attr("candidate_activation")); + + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor hidden_t = batch_hidden.Slice(bstart, bend); + Tensor proj_t = batch_proj.Slice(bstart, bend); + Tensor cell_t = batch_cell.Slice(bstart, bend); + Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend); + + int cur_batch_size = bend - bstart; + + if (n > 0) { + int pre_h_start = static_cast(batch_starts[n - 1]); + int pre_h_end = pre_h_start + cur_batch_size; + auto pre_proj_t = batch_proj.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, pre_proj_t, false, *weight, + false, static_cast(1.0), &gate_t, + static_cast(1.0)); + } else if (hidden_t0) { + // If n == 0 and there is no initialized hidden state, that is to say + // the H0 is zeros, the calculation W_h * H0 will be skiped. + // If n == 0 and there is initialized hidden state, calculate W_h * H0. + + // Since the batch computing for LSTMP reorders the input sequence + // according to their length. The initialized hidden state also needs + // to reorder. + Tensor ordered_h0, ordered_proj0; + ordered_proj0.Resize({1, proj_weight->dims()[1]}); + ordered_proj0.mutable_data(ctx.GetPlace()); + ReorderInitState(device_ctx, *hidden_t0, order, + &ordered_h0, true); + math::matmul(device_ctx, ordered_h0, false, + *proj_weight, false, static_cast(1.0), + &ordered_proj0, static_cast(0.0)); + math::matmul(device_ctx, ordered_proj0, false, + *weight, false, static_cast(1.0), + &gate_t, static_cast(1.0)); + } + + lstmp_value.gate_value = gate_t.data(); + lstmp_value.output_value = hidden_t.data(); + lstmp_value.state_value = cell_t.data(); + lstmp_value.state_active_value = cell_pre_act_t.data(); + math::LstmUnitFunctor::compute( + device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act, + cell_act, cand_act); + lstmp_value.prev_state_value = lstmp_value.state_value; + math::matmul(device_ctx, hidden_t, false, *proj_weight, + false, static_cast(1.0), &proj_t, + static_cast(0.0)); + } + + math::Batch2LoDTensorFunctor to_seq; + batch_proj.set_lod(batch_gate->lod()); + // restore the output hidden in LoDTensor from the batch hidden + to_seq(device_ctx, batch_proj, *proj_out); + + batch_cell.set_lod(batch_gate->lod()); + // restore the output cell state in LoDTensor from the batch cell + to_seq(device_ctx, batch_cell, *cell_out); + } +}; + +template +class LSTMPGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + + auto* proj_out = ctx.Input("Projection"); + auto* cell_out = ctx.Input("Cell"); + + auto* batch_gate = ctx.Input("BatchGate"); + auto* batch_cell_pre_act = ctx.Input("BatchCellPreAct"); + + auto* hidden_g = ctx.Input(framework::GradVarName("Projection")); + + auto* in_g = ctx.Output(framework::GradVarName("Input")); + auto* weight_g = ctx.Output(framework::GradVarName("Weight")); + auto* bias_g = ctx.Output(framework::GradVarName("Bias")); + + auto* h0 = ctx.Input("H0"); + auto* c0 = ctx.Input("C0"); + + auto* h0_g = ctx.Output(framework::GradVarName("H0")); + auto* c0_g = ctx.Output(framework::GradVarName("C0")); + + auto& device_ctx = ctx.template device_context(); + math::SetConstant zero; + if (weight_g) { + weight_g->mutable_data(ctx.GetPlace()); + zero(device_ctx, weight_g, static_cast(0.0)); + } + + // ordered_h0/c0 is the reordered hidden/cell initialization. + // ordered_h0_g/c0_g is the reordered gradient of hidden/cell + // initialization. + Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g; + const size_t* order = batch_gate->lod()[2].data(); + if (c0) { + ReorderInitState(device_ctx, *c0, order, &ordered_c0, + true); + } + if (c0 && c0_g) { + ordered_c0_g.mutable_data(c0_g->dims(), ctx.GetPlace()); + } + + auto in_dims = input->dims(); + auto out_dims = hidden_g->dims(); + int frame_size = static_cast(in_dims[1] / 4); + PADDLE_ENFORCE_EQ(frame_size, out_dims[1]); + + math::LstmMetaValue lstmp_value; + if (bias && ctx.Attr("use_peepholes")) { + T* bias_data = const_cast(bias->data()); + lstmp_value.check_ig = bias_data + 4 * frame_size; + lstmp_value.check_fg = lstmp_value.check_ig + frame_size; + lstmp_value.check_og = lstmp_value.check_fg + frame_size; + } else { + lstmp_value.check_ig = nullptr; + lstmp_value.check_fg = nullptr; + lstmp_value.check_og = nullptr; + } + + math::LstmMetaGrad lstmp_grad; + + if (bias && bias_g) { + bias_g->mutable_data(ctx.GetPlace()); + zero(device_ctx, bias_g, static_cast(0.0)); + } + if (bias && bias_g && ctx.Attr("use_peepholes")) { + T* bias_g_data = bias_g->data(); + lstmp_grad.check_ig_grad = bias_g_data + 4 * frame_size; + lstmp_grad.check_fg_grad = lstmp_grad.check_ig_grad + frame_size; + lstmp_grad.check_og_grad = lstmp_grad.check_fg_grad + frame_size; + } else { + lstmp_grad.check_ig_grad = nullptr; + lstmp_grad.check_fg_grad = nullptr; + lstmp_grad.check_og_grad = nullptr; + } + + math::LoDTensor2BatchFunctor to_batch; + + auto ToBatch = [&batch_gate, &to_batch]( + const DeviceContext& ctx, const framework::LoDTensor& src, + const framework::DDim& dims, framework::LoDTensor& dst) { + dst.mutable_data(dims, ctx.GetPlace()); + dst.set_lod(batch_gate->lod()); + to_batch(ctx, src, dst, false); + }; + + LoDTensor batch_proj, batch_proj_g, batch_cell; + ToBatch(device_ctx, *proj_out, out_dims, batch_proj); + ToBatch(device_ctx, *hidden_g, out_dims, batch_proj_g); + ToBatch(device_ctx, *cell_out, out_dims, batch_cell); + + LoDTensor batch_cell_g, batch_gate_g; + batch_cell_g.mutable_data(out_dims, ctx.GetPlace()); + // TODO(qingqing) support the case output cell has gradient. + // to_batch(device_ctx, *cell_g, batch_cell_g, false); + zero(device_ctx, &batch_cell_g, static_cast(0.0)); + batch_gate_g.mutable_data(batch_gate->dims(), ctx.GetPlace()); + batch_gate_g.set_lod(batch_gate->lod()); + + auto gate_act = math::detail::GetActivationType( + ctx.Attr("gate_activation")); + auto cell_act = math::detail::GetActivationType( + ctx.Attr("cell_activation")); + auto cand_act = math::detail::GetActivationType( + ctx.Attr("candidate_activation")); + + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + for (int n = static_cast(num_batch) - 1; n >= 0; n--) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + + Tensor gate = batch_gate->Slice(bstart, bend); + Tensor cell = batch_cell.Slice(bstart, bend); + Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend); + lstmp_value.gate_value = gate.data(); + lstmp_value.state_value = cell.data(); + lstmp_value.state_active_value = cell_pre_act.data(); + + Tensor out_g = batch_proj_g.Slice(bstart, bend); + Tensor gate_g = batch_gate_g.Slice(bstart, bend); + Tensor cell_g = batch_cell_g.Slice(bstart, bend); + lstmp_grad.state_grad = cell_g.data(); + lstmp_grad.gate_grad = gate_g.data(); + lstmp_grad.output_grad = out_g.data(); + + if (n > 0) { + int bstart_pre = static_cast(batch_starts[n - 1]); + Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart); + Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart); + lstmp_value.prev_state_value = cell_pre.data(); + lstmp_grad.prev_state_grad = cell_pre_g.data(); + } else { + lstmp_value.prev_state_value = c0 ? ordered_c0.data() : nullptr; + lstmp_grad.prev_state_grad = c0_g ? ordered_c0_g.data() : nullptr; + } + + int cur_batch_size = bend - bstart; + math::LstmUnitGradFunctor::compute( + device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size, + gate_act, cell_act, cand_act); + + if (n > 0) { + int pre_h_start = static_cast(batch_starts[n - 1]); + int pre_h_end = pre_h_start + cur_batch_size; + auto pre_proj_g = batch_proj_g.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, gate_g, false, *weight, true, + static_cast(1.0), &pre_proj_g, + static_cast(1.0)); + if (weight_g) { + /* backward weight */ + auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, pre_proj, true, gate_g, + false, static_cast(1.0), weight_g, + static_cast(1.0)); + } + } else { + if (h0 && weight_g) { + ReorderInitState(device_ctx, *h0, order, + &ordered_h0, true); + math::matmul(device_ctx, ordered_h0, true, gate_g, + false, static_cast(1.0), weight_g, + static_cast(1.0)); + } + if (h0 && h0_g) { + ordered_h0_g.mutable_data(h0_g->dims(), ctx.GetPlace()); + math::matmul(device_ctx, gate_g, false, *weight, + true, static_cast(1.0), + &ordered_h0_g, static_cast(0.0)); + } + } + } + + math::Batch2LoDTensorFunctor to_seq; + if (in_g) { + /* backward data */ + in_g->mutable_data(ctx.GetPlace()); + to_seq(device_ctx, batch_gate_g, *in_g); + } + if (bias && bias_g) { + /* backward bias */ + Tensor b_g = *bias_g; + b_g.Resize({bias_g->numel(), 1}); + Tensor gate_bias_g = b_g.Slice(0, 4 * frame_size); + math::ColwiseSum col_sum; + col_sum(device_ctx, batch_gate_g, &gate_bias_g); + } + + if (h0 && h0_g) { + ReorderInitState(device_ctx, ordered_h0_g, order, h0_g, + false); + } + if (c0 && c0_g) { + ReorderInitState(device_ctx, ordered_c0_g, order, c0_g, + false); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py new file mode 100644 index 0000000000..e35712ec06 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -0,0 +1,314 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +import unittest +import numpy as np +from op_test import OpTest + +SIGMOID_THRESHOLD_MIN = -40.0 +SIGMOID_THRESHOLD_MAX = 13.0 +EXP_MAX_INPUT = 40.0 + + +def identity(x): + return x + + +def sigmoid(x): + y = np.copy(x) + y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN + y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX + return 1. / (1. + np.exp(-y)) + + +def tanh(x): + y = -2. * x + y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT + return (2. / (1. + np.exp(y))) - 1. + + +def relu(x): + return np.maximum(x, 0) + + +ACTVATION = { + 'identity': identity, + 'sigmoid': sigmoid, + 'tanh': tanh, + 'relu': relu +} + + +# LSTM with recurrent projection Layer +def lstmp( + input, # T x 4D + lod, # 1 x N + h0=None, # N x D + c0=None, # N x D + w_r=None, # P x 5D + w_rh=None, # D x P + w_b=None, # 1 x 4D + w_c=None, # 1 x 3D + is_reverse=False, + act_gate=None, + act_cell=None, + act_cand=None): + def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand): + g = np.dot(r_pre, w_r) # 1 x 4D + g = g + x + g = np.reshape(g, (1, g.size)) + c, g_i, g_f, g_o = np.split(g, 4, axis=1) + if w_c is None: + g_i = act_gate(g_i) # 1 x D + g_f = act_gate(g_f) # 1 x D + else: + w_ic, w_fc, _ = np.split(w_c, 3, axis=1) + g_i = act_gate(g_i + w_ic * c_pre) # 1 x D + g_f = act_gate(g_f + w_fc * c_pre) # 1 x D + c = g_f * c_pre + g_i * act_cand(c) # 1 x D + + if w_c is None: + g_o = act_gate(g_o) # 1 x D + else: + _, _, w_oc = np.split(w_c, 3, axis=1) + g_o = act_gate(g_o + w_oc * c) # 1 x D + h = g_o * act_cell(c) + # projection + r = np.dot(h, w_rh) + return r, c + + def _reverse(x, lod): + y = np.zeros_like(x) + for i in range(len(lod) - 1): + b, e = lod[i], lod[i + 1] + y[b:e, :] = np.flip(x[b:e, :], 0) + return y + + offset = lod[0] + batch_size = len(offset) - 1 + # recurrent projection state + projection = [] + cell = [] + input = _reverse(input, offset) if is_reverse else input + if w_b is not None: + input = input + np.tile(w_b, (offset[-1], 1)) + for i in range(batch_size): + # compute one sequence + seq_len = offset[i + 1] - offset[i] + x = input[offset[i]:offset[i + 1], :] + r_pre = np.dot(h0[i], w_rh) # 1 x P + c_pre = c0[i] # 1 x D + for j in range(seq_len): + # compute one step + r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, act_gate, + act_cell, act_cand) + projection.append(r_pre.flatten()) + cell.append(c_pre.flatten()) + + projection = np.array(projection).astype('float64') + cell = np.array(cell).astype('float64') + + projection = _reverse(projection, offset) if is_reverse else projection + cell = _reverse(cell, offset) if is_reverse else cell + + assert projection.shape == (input.shape[0], w_r.shape[0]) # T x P + assert cell.shape == (input.shape[0], input.shape[1] / 4) # T x D + return projection, cell + + +class TestLstmOp(OpTest): + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + # hidden size + self.D = 16 + # projection size + self.P = 10 + + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + + self.has_initial_state = False + self.is_reverse = False + self.use_peepholes = True + + def setUp(self): + self.set_argument() + self.op_type = 'lstmp' + + T = self.lod[0][-1] + N = len(self.lod[0]) - 1 + + x = np.random.normal(size=(T, 4 * self.D)).astype('float64') + if self.has_initial_state: + h0 = np.random.normal(size=(N, self.D)).astype('float64') + c0 = np.random.normal(size=(N, self.D)).astype('float64') + else: + h0 = np.zeros((N, self.D)).astype('float64') + c0 = np.zeros((N, self.D)).astype('float64') + w = np.random.normal(size=(self.P, 4 * self.D)).astype('float64') + if self.use_peepholes: + b = np.random.normal(size=(1, 7 * self.D)).astype('float64') + else: + b = np.random.normal(size=(1, 4 * self.D)).astype('float64') + + w_b = b[:, 0:4 * self.D] + w_c = b[:, 4 * self.D:] if self.use_peepholes else None + w_rh = np.random.normal(size=(self.D, self.P)).astype('float64') + r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse, + ACTVATION[self.act_gate], ACTVATION[self.act_cell], + ACTVATION[self.act_cand]) + + self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh} + + self.inputs['Bias'] = b + + if self.has_initial_state: + self.inputs['H0'] = h0 + self.inputs['C0'] = c0 + + self.outputs = { + 'Projection': (r, self.lod), + 'Cell': (c, self.lod), + } + self.attrs = { + 'use_peepholes': self.use_peepholes, + 'is_reverse': self.is_reverse, + 'gate_activation': self.act_gate, + 'cell_activation': self.act_cell, + 'candidate_activation': self.act_cand + } + + def test_check_output(self): + self.check_output(atol=1e-8) + + """ + def test_check_grad(self): + # TODO(qingqing) remove folowing lines after the check_grad is refined. + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) + """ + + +""" +class TestLstmOpHasInitial(TestLstmOp): + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 + + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + + self.has_initial_state = True + self.is_reverse = True + self.use_peepholes = True + + def test_check_grad(self): + # TODO(qingqing) remove folowing lines after the check_grad is refined. + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'], + max_relative_error=5e-4) + + def test_check_grad_ingore_bias(self): + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight'], ['Hidden'], + max_relative_error=5e-4, + no_grad_set=set('Bias')) + + def test_check_grad_ingore_weight(self): + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Bias'], ['Hidden'], + max_relative_error=5e-4, + no_grad_set=set('Weight')) + + def test_check_grad_ingore_input(self): + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Weight', 'Bias'], ['Hidden'], + max_relative_error=5e-4, + no_grad_set=set('Input')) + + def test_check_grad_ingore_h0(self): + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight', 'Bias', 'C0'], ['Hidden'], + max_relative_error=5e-4, + no_grad_set=set('H0')) + + def test_check_grad_ingore_c0(self): + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight', 'Bias', 'H0'], ['Hidden'], + max_relative_error=5e-4, + no_grad_set=set('C0')) +""" + + +class TestLstmOpRerverse(TestLstmOp): + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 + self.P = 10 + + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + + self.has_initial_state = False + self.is_reverse = True + self.use_peepholes = True + + +class TestLstmOpNotUsePeepholes(TestLstmOp): + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 + self.P = 10 + + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + + self.has_initial_state = False + self.is_reverse = True + self.use_peepholes = False + + +if __name__ == '__main__': + unittest.main() From 22032e49cb88bf8942f96cbbcd28fcdcb553cb50 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 22 Jan 2018 21:12:39 +0800 Subject: [PATCH 02/17] Add python wrapper for multiplex operator. --- doc/api/v2/fluid/layers.rst | 5 ++ python/paddle/v2/fluid/layers/nn.py | 55 +++++++++++++++++++-- python/paddle/v2/fluid/tests/test_layers.py | 10 ++++ 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst index 986026e0b9..4c4713b17c 100644 --- a/doc/api/v2/fluid/layers.rst +++ b/doc/api/v2/fluid/layers.rst @@ -509,3 +509,8 @@ sequence_reshape ---------------- .. autofunction:: paddle.v2.fluid.layers.sequence_reshape :noindex: + +multiplex +--------- +.. autofunction:: paddle.v2.fluid.layers.multiplex + :noindex: diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index b1db16a83e..26e36ec0bb 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -28,7 +28,7 @@ __all__ = [ 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', 'sequence_first_step', 'sequence_last_step', 'dropout', 'split', - 'l2_normalize', 'matmul', 'warpctc', 'sequence_reshape' + 'l2_normalize', 'matmul', 'warpctc', 'sequence_reshape', 'multiplex' ] @@ -1813,11 +1813,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): - If both are 2-D, they are multiplied like conventional matrices. - If either is n-D, it is treated as a stack of matrices residing in the - last two dimensions and a batched matrix multiply supporting broadcast + last two dimensions and a batched matrix multiply supporting broadcast applies on the two tensors. - Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and - nontransposed, the prepended or appended dimension :math:`1` will be + Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and + nontransposed, the prepended or appended dimension :math:`1` will be removed after matrix multiplication. Args: @@ -1971,3 +1971,50 @@ def sequence_reshape(input, new_dim): outputs={'Out': [out]}, attrs={'new_dim': new_dim}) return out + + +def multiplex(inputs, index): + """ + **Multiplex Layer** + + Referring to the given index variable, this layer gathers from the input + variables to output a multiplex variable. Assuming that there are :math:`m` + input variables and let :math:`I_i` represents the i-th input variable and i + is in [0, :math:`m`). All input variables are tensors with same shape + [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. Please note that rank of the + input tensor should be at least 2. Each input variable will be viewed as a + 2-D matrix with shape [:math:`M`, :math:`N`] where :math:`M` for :math:`d_0` + and :math:`N` for :math:`d_1` * :math:`d_2` * ... * :math:`d_R`. Let + :math:`I_i[j]` be the j-th row of the i-th input variable. The given index + variable should be a 2-D tensor with shape [:math:`M`, 1]. Let `ID[i]` be + the i-th index value of index variable. Then the output variable will be a + tensor with shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. If we view + the output tensor as a 2-D matrix with shape [:math:`M`, :math:`N`] and let + :math:`O[i]` be the i-th row of the matrix, then values of `O[i]` come from + :math:`I_{ID[i]}[i]`. + + Args: + inputs (list): Input variables which are tensors with same shape and the + rank is at least 2. + index (Variable): Tensor, index variable which is a 2-D tensor with + shape [M, 1] where M for batch size. + + Returns: + Variable: Multiplex variable gathered from input variables. + + Examples: + .. code-block:: python + + x1 = fluid.layers.data(name='x1', shape=[4], dtype='float32') + x2 = fluid.layers.data(name='x2', shape=[4], dtype='float32') + index = fluid.layers.data(name='index', shape=[1], dtype='int32') + out = fluid.layers.multiplex(inputs=[x1, x2], index=index) + """ + helper = LayerHelper('multiplex', **locals()) + out = helper.create_tmp_variable(helper.input_dtype()) + helper.append_op( + type='multiplex', + inputs={'X': inputs, + 'Ids': index}, + outputs={'Out': [out]}) + return out diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 709abd6c6a..dc143f6b9f 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -225,6 +225,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_multiplex(self): + program = Program() + with program_guard(program): + x1 = layers.data(name='x1', shape=[4], dtype='float32') + x2 = layers.data(name='x2', shape=[4], dtype='float32') + index = layers.data(name='index', shape=[1], dtype='int32') + out = layers.multiplex(inputs=[x1, x2], index=index) + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() From 552c901204744003f0653dbcc9afe615d5a66334 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 23 Jan 2018 04:04:06 -0800 Subject: [PATCH 03/17] Enable backward computation in lstmp_op --- paddle/operators/lstmp_op.cc | 53 ++++-- .../operators/{lstmp_op.cu.cc => lstmp_op.cu} | 0 paddle/operators/lstmp_op.h | 151 +++++++++++++++--- python/paddle/v2/fluid/tests/test_lstmp_op.py | 59 ++++--- 4 files changed, 206 insertions(+), 57 deletions(-) rename paddle/operators/{lstmp_op.cu.cc => lstmp_op.cu} (100%) diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc index 4c7f7713ee..266612294c 100644 --- a/paddle/operators/lstmp_op.cc +++ b/paddle/operators/lstmp_op.cc @@ -39,21 +39,12 @@ class LSTMPOp : public framework::OperatorWithKernel { "Output(BatchGate) of LSTMP should not be null."); PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), "Output(BatchGate) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"), + "Output(BatchHidden) of LSTMP should not be null."); auto in_dims = ctx->GetInputDim("Input"); PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2."); - if (ctx->HasInput("H0")) { - PADDLE_ENFORCE(ctx->HasInput("C0"), - "Input(C0) and Input(H0) of LSTMP should not " - "be null at the same time."); - auto h_dims = ctx->GetInputDim("H0"); - auto c_dims = ctx->GetInputDim("C0"); - PADDLE_ENFORCE(h_dims == c_dims, - "The dimension of Input(H0) and Input(C0) " - "should be the same."); - } - int frame_size = in_dims[1] / 4; auto w_dims = ctx->GetInputDim("Weight"); auto proj_dims = ctx->GetInputDim("ProjWeight"); @@ -75,6 +66,18 @@ class LSTMPOp : public framework::OperatorWithKernel { "should be %d.", frame_size); + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(C0) and Input(H0) of LSTMP should not " + "be null at the same time."); + auto h_dims = ctx->GetInputDim("H0"); + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + ctx->SetOutputDim("OrderedP0", {h_dims[0], proj_dims[1]}); + } + auto b_dims = ctx->GetInputDim("Bias"); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims[0], 1, @@ -98,6 +101,7 @@ class LSTMPOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Cell", out_dims); ctx->SetOutputDim("BatchGate", in_dims); ctx->SetOutputDim("BatchCellPreAct", out_dims); + ctx->SetOutputDim("BatchHidden", out_dims); ctx->ShareLoD("Input", "Projection"); ctx->ShareLoD("Input", "Cell"); } @@ -169,6 +173,15 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "(LoDTensor) This LoDTensor is obtained in the forward and used " "in the backward.") .AsIntermediate(); + AddOutput("BatchHidden", + "(LoDTensor) This LoDTensor is obtained in the forward and used " + "in the backward.") + .AsIntermediate(); + AddOutput("OrderedP0", + "(Tensor) the projection of the initial hidden state " + "H0. This is a tensor with shape (N x P), where N is the " + "batch size and P is the hidden size.") + .AsIntermediate(); AddAttr("use_peepholes", "(bool, defalut: True) " "whether to enable diagonal/peephole connections.") @@ -177,6 +190,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, defalut: False) " "whether to compute reversed LSTMP.") .SetDefault(false); + AddAttr("share_cell_act", + "(bool, defalut: True) " + "whether to share activation with cell output. " + "If false, the projection would be linear, else " + "through an activation same with the cell output.") + .SetDefault(true); AddAttr( "gate_activation", "(string, default: sigmoid)" @@ -213,7 +232,7 @@ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\ h_t = o_t \odot act_h(c_t) -r_t = W_{rh}h_t +r_t = act_h'(W_{rh}h_t) $$ where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix @@ -229,7 +248,8 @@ layer. The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$ are the cell input and cell output activation functions and `tanh` is usually -used for them. +used for them. If `share_cell_act` setted to `False`, $act_h'$ will be linear +else will be same with $act_h$. Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$ operations on the input $x_{t}$ are NOT included in this operator. @@ -246,12 +266,14 @@ class LSTMPGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input(Input) of LSTMP should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Hidden"), - "Input(Hidden) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Projection"), + "Input(Projection) of LSTMP should not be null."); PADDLE_ENFORCE(ctx->HasInput("Cell"), "Input(Cell) of LSTMP should not be null."); PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) of LSTMP should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ProjWeight"), + "Input(ProjWeight) of LSTMP should not be null."); PADDLE_ENFORCE(ctx->HasInput("Bias"), "Input(Bias) of LSTMP should not be null."); @@ -268,6 +290,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { SetOutGradDim("Input"); SetOutGradDim("Weight"); + SetOutGradDim("ProjWeight"); SetOutGradDim("Bias"); SetOutGradDim("H0"); SetOutGradDim("C0"); diff --git a/paddle/operators/lstmp_op.cu.cc b/paddle/operators/lstmp_op.cu similarity index 100% rename from paddle/operators/lstmp_op.cu.cc rename to paddle/operators/lstmp_op.cu diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h index f5a38b2ff5..9467ccdb5a 100644 --- a/paddle/operators/lstmp_op.h +++ b/paddle/operators/lstmp_op.h @@ -13,18 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/op_registry.h" +#include "paddle/operators/activation_op.h" #include "paddle/operators/math/detail/activation_functions.h" #include "paddle/operators/math/lstm_compute.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/sequence2batch.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + namespace paddle { namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + template inline void ReorderInitState(const DeviceContext& ctx, const framework::Tensor& src, const size_t* index, @@ -37,6 +44,21 @@ inline void ReorderInitState(const DeviceContext& ctx, template class LSTMPKernel : public framework::OpKernel { public: + template + void ActCompute(const math::detail::ActivationType act_type, const Device& d, + X x, Y y) const { + if (act_type == math::detail::ActivationType::kIdentity) + y.device(d) = x; + else if (act_type == math::detail::ActivationType::kSigmoid) + SigmoidFunctor()(d, x, y); + else if (act_type == math::detail::ActivationType::kTanh) + TanhFunctor()(d, x, y); + else if (act_type == math::detail::ActivationType::kReLU) + ReluFunctor()(d, x, y); + else + PADDLE_THROW("unsupported activation type"); + } + void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); auto* weight = ctx.Input("Weight"); @@ -44,6 +66,7 @@ class LSTMPKernel : public framework::OpKernel { auto* bias = ctx.Input("Bias"); auto* hidden_t0 = ctx.Input("H0"); + auto* ordered_proj0 = ctx.Output("OrderedP0"); auto* cell_t0 = ctx.Input("C0"); auto* batch_gate = ctx.Output("BatchGate"); @@ -97,12 +120,13 @@ class LSTMPKernel : public framework::OpKernel { } // Use the local variable as here. - LoDTensor batch_hidden, batch_proj, batch_cell; + LoDTensor batch_proj, batch_cell; auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); - batch_hidden.mutable_data(dims, ctx.GetPlace()); // T x D + batch_cell_pre_act->mutable_data(dims, ctx.GetPlace()); + auto* batch_hidden = ctx.Output("BatchHidden"); + batch_hidden->mutable_data(dims, ctx.GetPlace()); // T x D batch_proj.mutable_data(proj_dims, ctx.GetPlace()); // T x P batch_cell.mutable_data(dims, ctx.GetPlace()); // T x D - batch_cell_pre_act->mutable_data(dims, ctx.GetPlace()); auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; @@ -112,13 +136,15 @@ class LSTMPKernel : public framework::OpKernel { ctx.Attr("cell_activation")); auto cand_act = math::detail::GetActivationType( ctx.Attr("candidate_activation")); + auto share_cell_act = ctx.Attr("share_cell_act"); + auto& place = *ctx.template device_context().eigen_device(); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor hidden_t = batch_hidden.Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); Tensor proj_t = batch_proj.Slice(bstart, bend); Tensor cell_t = batch_cell.Slice(bstart, bend); Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend); @@ -140,15 +166,19 @@ class LSTMPKernel : public framework::OpKernel { // Since the batch computing for LSTMP reorders the input sequence // according to their length. The initialized hidden state also needs // to reorder. - Tensor ordered_h0, ordered_proj0; - ordered_proj0.Resize({1, proj_weight->dims()[1]}); - ordered_proj0.mutable_data(ctx.GetPlace()); + + Tensor ordered_h0; + ordered_proj0->mutable_data(ctx.GetPlace()); ReorderInitState(device_ctx, *hidden_t0, order, &ordered_h0, true); math::matmul(device_ctx, ordered_h0, false, *proj_weight, false, static_cast(1.0), - &ordered_proj0, static_cast(0.0)); - math::matmul(device_ctx, ordered_proj0, false, + ordered_proj0, static_cast(0.0)); + if (share_cell_act) { + auto proj0_dev = EigenMatrix::From(*ordered_proj0); + ActCompute(cell_act, place, proj0_dev, proj0_dev); + } + math::matmul(device_ctx, *ordered_proj0, false, *weight, false, static_cast(1.0), &gate_t, static_cast(1.0)); } @@ -164,6 +194,10 @@ class LSTMPKernel : public framework::OpKernel { math::matmul(device_ctx, hidden_t, false, *proj_weight, false, static_cast(1.0), &proj_t, static_cast(0.0)); + if (share_cell_act) { + auto proj_t_dev = EigenMatrix::From(proj_t); + ActCompute(cell_act, place, proj_t_dev, proj_t_dev); + } } math::Batch2LoDTensorFunctor to_seq; @@ -180,9 +214,26 @@ class LSTMPKernel : public framework::OpKernel { template class LSTMPGradKernel : public framework::OpKernel { public: + template + void ActGradCompute(const math::detail::ActivationType act_type, + const Device& d, X x, Y y, DX dx, DY dy) const { + // x is dummy and won't be used even in Relu(use y instead) + if (act_type == math::detail::ActivationType::kIdentity) + dx.device(d) = dy; + else if (act_type == math::detail::ActivationType::kSigmoid) + SigmoidGradFunctor()(d, x, y, dy, dx); + else if (act_type == math::detail::ActivationType::kTanh) + TanhGradFunctor()(d, x, y, dy, dx); + else if (act_type == math::detail::ActivationType::kReLU) + ReluGradFunctor()(d, x, y, dy, dx); + else + PADDLE_THROW("unsupported activation type"); + } + void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); auto* weight = ctx.Input("Weight"); + auto* proj_weight = ctx.Input("ProjWeight"); auto* bias = ctx.Input("Bias"); auto* proj_out = ctx.Input("Projection"); @@ -190,14 +241,19 @@ class LSTMPGradKernel : public framework::OpKernel { auto* batch_gate = ctx.Input("BatchGate"); auto* batch_cell_pre_act = ctx.Input("BatchCellPreAct"); + auto* batch_hidden = ctx.Input("BatchHidden"); - auto* hidden_g = ctx.Input(framework::GradVarName("Projection")); + auto* projection_g = + ctx.Input(framework::GradVarName("Projection")); auto* in_g = ctx.Output(framework::GradVarName("Input")); auto* weight_g = ctx.Output(framework::GradVarName("Weight")); + auto* proj_weight_g = + ctx.Output(framework::GradVarName("ProjWeight")); auto* bias_g = ctx.Output(framework::GradVarName("Bias")); auto* h0 = ctx.Input("H0"); + auto* ordered_proj0 = ctx.Input("OrderedP0"); auto* c0 = ctx.Input("C0"); auto* h0_g = ctx.Output(framework::GradVarName("H0")); @@ -209,6 +265,10 @@ class LSTMPGradKernel : public framework::OpKernel { weight_g->mutable_data(ctx.GetPlace()); zero(device_ctx, weight_g, static_cast(0.0)); } + if (proj_weight_g) { + proj_weight_g->mutable_data(ctx.GetPlace()); + zero(device_ctx, proj_weight_g, static_cast(0.0)); + } // ordered_h0/c0 is the reordered hidden/cell initialization. // ordered_h0_g/c0_g is the reordered gradient of hidden/cell @@ -224,7 +284,8 @@ class LSTMPGradKernel : public framework::OpKernel { } auto in_dims = input->dims(); - auto out_dims = hidden_g->dims(); + auto out_dims = cell_out->dims(); + framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]}); int frame_size = static_cast(in_dims[1] / 4); PADDLE_ENFORCE_EQ(frame_size, out_dims[1]); @@ -267,10 +328,11 @@ class LSTMPGradKernel : public framework::OpKernel { to_batch(ctx, src, dst, false); }; - LoDTensor batch_proj, batch_proj_g, batch_cell; - ToBatch(device_ctx, *proj_out, out_dims, batch_proj); - ToBatch(device_ctx, *hidden_g, out_dims, batch_proj_g); - ToBatch(device_ctx, *cell_out, out_dims, batch_cell); + LoDTensor batch_hidden_g, batch_proj, batch_proj_g, batch_cell; + batch_hidden_g.mutable_data(out_dims, ctx.GetPlace()); + ToBatch(device_ctx, *proj_out, proj_dims, batch_proj); // T x P + ToBatch(device_ctx, *projection_g, proj_dims, batch_proj_g); // T x P + ToBatch(device_ctx, *cell_out, out_dims, batch_cell); // T x D LoDTensor batch_cell_g, batch_gate_g; batch_cell_g.mutable_data(out_dims, ctx.GetPlace()); @@ -286,6 +348,8 @@ class LSTMPGradKernel : public framework::OpKernel { ctx.Attr("cell_activation")); auto cand_act = math::detail::GetActivationType( ctx.Attr("candidate_activation")); + auto share_cell_act = ctx.Attr("share_cell_act"); + auto& place = *ctx.template device_context().eigen_device(); auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; @@ -293,6 +357,19 @@ class LSTMPGradKernel : public framework::OpKernel { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); + Tensor cur_proj = batch_proj.Slice(bstart, bend); + Tensor proj_g = batch_proj_g.Slice(bstart, bend); + if (share_cell_act) { + auto cur_proj_dev = EigenMatrix::From(cur_proj); + auto proj_g_dev = EigenMatrix::From(proj_g); + ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev, + proj_g_dev); + } + Tensor out_g = batch_hidden_g.Slice(bstart, bend); + math::matmul(device_ctx, proj_g, false, *proj_weight, + true, static_cast(1.0), &out_g, + static_cast(0.0)); + Tensor gate = batch_gate->Slice(bstart, bend); Tensor cell = batch_cell.Slice(bstart, bend); Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend); @@ -300,7 +377,6 @@ class LSTMPGradKernel : public framework::OpKernel { lstmp_value.state_value = cell.data(); lstmp_value.state_active_value = cell_pre_act.data(); - Tensor out_g = batch_proj_g.Slice(bstart, bend); Tensor gate_g = batch_gate_g.Slice(bstart, bend); Tensor cell_g = batch_cell_g.Slice(bstart, bend); lstmp_grad.state_grad = cell_g.data(); @@ -337,19 +413,48 @@ class LSTMPGradKernel : public framework::OpKernel { false, static_cast(1.0), weight_g, static_cast(1.0)); } + if (proj_weight_g) { + /* backward proj weigh */ + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + math::matmul(device_ctx, hidden_t, true, proj_g, + false, static_cast(1.0), + proj_weight_g, static_cast(1.0)); + } } else { if (h0 && weight_g) { ReorderInitState(device_ctx, *h0, order, &ordered_h0, true); - math::matmul(device_ctx, ordered_h0, true, gate_g, - false, static_cast(1.0), weight_g, - static_cast(1.0)); + if (weight_g) { + math::matmul(device_ctx, *ordered_proj0, true, + gate_g, false, static_cast(1.0), + weight_g, static_cast(1.0)); + } } - if (h0 && h0_g) { + if (h0 && (h0_g || proj_weight_g)) { ordered_h0_g.mutable_data(h0_g->dims(), ctx.GetPlace()); + Tensor proj0_g; + proj0_g.Resize({in_dims[0], proj_weight->dims()[1]}); + proj0_g.mutable_data(ctx.GetPlace()); math::matmul(device_ctx, gate_g, false, *weight, - true, static_cast(1.0), - &ordered_h0_g, static_cast(0.0)); + true, static_cast(1.0), &proj0_g, + static_cast(0.0)); + if (share_cell_act) { + auto proj0_dev = EigenMatrix::From(*ordered_proj0); + auto proj0_g_dev = EigenMatrix::From(proj0_g); + ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev, + proj0_g_dev); + } + // Tensor proj0_g = proj_g.Slice(bstart, bend); + if (h0_g) { + math::matmul( + device_ctx, proj0_g, false, *proj_weight, true, + static_cast(1.0), &ordered_h0_g, static_cast(0.0)); + } + if (proj_weight_g) { + math::matmul(device_ctx, ordered_h0, true, + proj0_g, false, static_cast(1.0), + proj_weight_g, static_cast(1.0)); + } } } } diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py index e35712ec06..81e06063fc 100644 --- a/python/paddle/v2/fluid/tests/test_lstmp_op.py +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -62,7 +62,8 @@ def lstmp( is_reverse=False, act_gate=None, act_cell=None, - act_cand=None): + act_cand=None, + share_cell_act=True): def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand): g = np.dot(r_pre, w_r) # 1 x 4D g = g + x @@ -85,6 +86,8 @@ def lstmp( h = g_o * act_cell(c) # projection r = np.dot(h, w_rh) + if share_cell_act: + r = act_cell(r) return r, c def _reverse(x, lod): @@ -107,6 +110,8 @@ def lstmp( seq_len = offset[i + 1] - offset[i] x = input[offset[i]:offset[i + 1], :] r_pre = np.dot(h0[i], w_rh) # 1 x P + if share_cell_act: + r_pre = act_cell(r_pre) c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step @@ -138,6 +143,7 @@ class TestLstmOp(OpTest): self.act_cell = 'tanh' self.act_cand = 'tanh' + self.share_cell_act = True self.has_initial_state = False self.is_reverse = False self.use_peepholes = True @@ -167,7 +173,7 @@ class TestLstmOp(OpTest): w_rh = np.random.normal(size=(self.D, self.P)).astype('float64') r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse, ACTVATION[self.act_gate], ACTVATION[self.act_cell], - ACTVATION[self.act_cand]) + ACTVATION[self.act_cand], self.share_cell_act) self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh} @@ -192,28 +198,30 @@ class TestLstmOp(OpTest): def test_check_output(self): self.check_output(atol=1e-8) - """ def test_check_grad(self): # TODO(qingqing) remove folowing lines after the check_grad is refined. N = len(self.lod[0]) - 1 + self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) - """ + ['Input', 'Weight', 'Bias'], ['Projection'], + max_relative_error=5e-3) -""" class TestLstmOpHasInitial(TestLstmOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 + self.P = 5 self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.share_cell_act = True self.has_initial_state = True self.is_reverse = True self.use_peepholes = True @@ -221,63 +229,74 @@ class TestLstmOpHasInitial(TestLstmOp): def test_check_grad(self): # TODO(qingqing) remove folowing lines after the check_grad is refined. N = len(self.lod[0]) - 1 + self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'], - max_relative_error=5e-4) + ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Projection'], + max_relative_error=5e-3) def test_check_grad_ingore_bias(self): N = len(self.lod[0]) - 1 + self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight'], ['Hidden'], - max_relative_error=5e-4, + ['Input', 'Weight'], ['Projection'], + max_relative_error=5e-3, no_grad_set=set('Bias')) def test_check_grad_ingore_weight(self): N = len(self.lod[0]) - 1 + self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Bias'], ['Hidden'], - max_relative_error=5e-4, + ['Input', 'Bias'], ['Projection'], + max_relative_error=5e-3, no_grad_set=set('Weight')) def test_check_grad_ingore_input(self): N = len(self.lod[0]) - 1 + self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Weight', 'Bias'], ['Hidden'], - max_relative_error=5e-4, + ['Weight', 'Bias'], ['Projection'], + max_relative_error=5e-3, no_grad_set=set('Input')) def test_check_grad_ingore_h0(self): N = len(self.lod[0]) - 1 + self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias', 'C0'], ['Hidden'], - max_relative_error=5e-4, + ['Input', 'Weight', 'Bias', 'C0'], ['Projection'], + max_relative_error=5e-3, no_grad_set=set('H0')) def test_check_grad_ingore_c0(self): N = len(self.lod[0]) - 1 + self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias', 'H0'], ['Hidden'], - max_relative_error=5e-4, + ['Input', 'Weight', 'Bias', 'H0'], ['Projection'], + max_relative_error=5e-3, no_grad_set=set('C0')) -""" class TestLstmOpRerverse(TestLstmOp): @@ -290,6 +309,7 @@ class TestLstmOpRerverse(TestLstmOp): self.act_cell = 'tanh' self.act_cand = 'tanh' + self.share_cell_act = True self.has_initial_state = False self.is_reverse = True self.use_peepholes = True @@ -305,6 +325,7 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): self.act_cell = 'tanh' self.act_cand = 'tanh' + self.share_cell_act = True self.has_initial_state = False self.is_reverse = True self.use_peepholes = False From 7a5b8ffacb2aa64981d6262790501b10257b6321 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 23 Jan 2018 09:51:15 -0800 Subject: [PATCH 04/17] Pass grad checking for projection weight --- paddle/operators/lstmp_op.cc | 4 +- paddle/operators/lstmp_op.h | 18 ++++---- python/paddle/v2/fluid/tests/test_lstmp_op.py | 41 ++++++++++++------- 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc index 266612294c..932e76e913 100644 --- a/paddle/operators/lstmp_op.cc +++ b/paddle/operators/lstmp_op.cc @@ -217,7 +217,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Long-Short Term Memory with Recurrent Projection (LSTMP) Operator. -LATMP is stand LSTM appended by a recurrent projection layer to reduce the +LSTMP is stand LSTM appended by a recurrent projection layer to reduce the number of parameters, espeacially when the output size is relative large. The formula is as follows: @@ -232,7 +232,7 @@ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\ h_t = o_t \odot act_h(c_t) -r_t = act_h'(W_{rh}h_t) +r_t = act_{h'}(W_{rh}h_t) $$ where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h index 9467ccdb5a..0048f7e1c6 100644 --- a/paddle/operators/lstmp_op.h +++ b/paddle/operators/lstmp_op.h @@ -365,10 +365,18 @@ class LSTMPGradKernel : public framework::OpKernel { ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev, proj_g_dev); } + /* hidden state backwarad */ Tensor out_g = batch_hidden_g.Slice(bstart, bend); math::matmul(device_ctx, proj_g, false, *proj_weight, true, static_cast(1.0), &out_g, static_cast(0.0)); + /* projection weight backward*/ + if (proj_weight_g) { + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + math::matmul(device_ctx, hidden_t, true, proj_g, + false, static_cast(1.0), + proj_weight_g, static_cast(1.0)); + } Tensor gate = batch_gate->Slice(bstart, bend); Tensor cell = batch_cell.Slice(bstart, bend); @@ -407,19 +415,12 @@ class LSTMPGradKernel : public framework::OpKernel { static_cast(1.0), &pre_proj_g, static_cast(1.0)); if (weight_g) { - /* backward weight */ + /* weight backward*/ auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end); math::matmul(device_ctx, pre_proj, true, gate_g, false, static_cast(1.0), weight_g, static_cast(1.0)); } - if (proj_weight_g) { - /* backward proj weigh */ - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - math::matmul(device_ctx, hidden_t, true, proj_g, - false, static_cast(1.0), - proj_weight_g, static_cast(1.0)); - } } else { if (h0 && weight_g) { ReorderInitState(device_ctx, *h0, order, @@ -444,7 +445,6 @@ class LSTMPGradKernel : public framework::OpKernel { ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev, proj0_g_dev); } - // Tensor proj0_g = proj_g.Slice(bstart, bend); if (h0_g) { math::matmul( device_ctx, proj0_g, false, *proj_weight, true, diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py index 81e06063fc..a0f6955d77 100644 --- a/python/paddle/v2/fluid/tests/test_lstmp_op.py +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -207,8 +207,8 @@ class TestLstmOp(OpTest): self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias'], ['Projection'], - max_relative_error=5e-3) + ['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'], + max_relative_error=1e-2) class TestLstmOpHasInitial(TestLstmOp): @@ -235,8 +235,9 @@ class TestLstmOpHasInitial(TestLstmOp): self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Projection'], - max_relative_error=5e-3) + ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'], + ['Projection'], + max_relative_error=1e-2) def test_check_grad_ingore_bias(self): N = len(self.lod[0]) - 1 @@ -246,8 +247,8 @@ class TestLstmOpHasInitial(TestLstmOp): self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight'], ['Projection'], - max_relative_error=5e-3, + ['Input', 'ProjWeight', 'Weight'], ['Projection'], + max_relative_error=1e-2, no_grad_set=set('Bias')) def test_check_grad_ingore_weight(self): @@ -258,10 +259,22 @@ class TestLstmOpHasInitial(TestLstmOp): self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Bias'], ['Projection'], - max_relative_error=5e-3, + ['Input', 'ProjWeight', 'Bias'], ['Projection'], + max_relative_error=1e-2, no_grad_set=set('Weight')) + def test_check_grad_ingore_proj_weight(self): + N = len(self.lod[0]) - 1 + self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight', 'Bias'], ['Projection'], + max_relative_error=1e-2, + no_grad_set=set('ProjWeight')) + def test_check_grad_ingore_input(self): N = len(self.lod[0]) - 1 self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') @@ -270,8 +283,8 @@ class TestLstmOpHasInitial(TestLstmOp): self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Weight', 'Bias'], ['Projection'], - max_relative_error=5e-3, + ['Weight', 'ProjWeight', 'Bias'], ['Projection'], + max_relative_error=1e-2, no_grad_set=set('Input')) def test_check_grad_ingore_h0(self): @@ -282,8 +295,8 @@ class TestLstmOpHasInitial(TestLstmOp): self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias', 'C0'], ['Projection'], - max_relative_error=5e-3, + ['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'], + max_relative_error=1e-2, no_grad_set=set('H0')) def test_check_grad_ingore_c0(self): @@ -294,8 +307,8 @@ class TestLstmOpHasInitial(TestLstmOp): self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias', 'H0'], ['Projection'], - max_relative_error=5e-3, + ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'], + max_relative_error=1e-2, no_grad_set=set('C0')) From a3d1f86947dc46dfbff734cf0b5b529eaff4703e Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 23 Jan 2018 10:27:44 -0800 Subject: [PATCH 05/17] Add unit test for linear projection --- python/paddle/v2/fluid/tests/test_lstmp_op.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py index a0f6955d77..8835cae504 100644 --- a/python/paddle/v2/fluid/tests/test_lstmp_op.py +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -192,7 +192,8 @@ class TestLstmOp(OpTest): 'is_reverse': self.is_reverse, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, - 'candidate_activation': self.act_cand + 'candidate_activation': self.act_cand, + 'share_cell_act': self.share_cell_act } def test_check_output(self): @@ -340,9 +341,25 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): self.share_cell_act = True self.has_initial_state = False - self.is_reverse = True + self.is_reverse = False self.use_peepholes = False +class TestLstmOpNotShareCellAct(TestLstmOp): + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 + self.P = 10 + + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + + self.share_cell_act = False + self.has_initial_state = False + self.is_reverse = False + self.use_peepholes = True + + if __name__ == '__main__': unittest.main() From db1f6a591ae9291de9877099e6801f101e679969 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 23 Jan 2018 19:29:12 -0800 Subject: [PATCH 06/17] Update doc in lstmp_op --- paddle/operators/lstmp_op.cc | 86 ++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc index 932e76e913..85be64f44c 100644 --- a/paddle/operators/lstmp_op.cc +++ b/paddle/operators/lstmp_op.cc @@ -120,7 +120,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { LSTMPOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Input", - "(LoDTensor) the first input is a LodTensor, which support " + "(LoDTensor) the input for sequence data, which supports " "variable-time length input sequence. The underlying tensor in " "this LoDTensor is a matrix with shape (T X 4D), where T is the " "total time steps in this mini-batch, D is the hidden size."); @@ -132,21 +132,23 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("C0", "(Tensor, optional) the initial cell state is an optional " "input. This is a tensor with shape (N x D), where N is the " - "batch size. `H0` and `C0` can be NULL but only at the same time") + "batch size. Only one of `H0` and `C0` can be NULL at the same " + "time.") .AsDispensable(); AddInput("Weight", "(Tensor) the learnable hidden-hidden weights." - " - The shape is (P x 4D), where P is the recurrent projection " - "layer size and D is the hidden size. " + " - The shape is (P x 4D), where P is the projection layer size " + "and D is the hidden size." " - Weight = {W_cr, W_ir, W_fr, W_or}"); AddInput("ProjWeight", - "(Tensor) the learnable weight `W_rh` of the projection layer." + "(Tensor) the learnable weight of the projection layer." " - The shape is (D x P), where P is the recurrent projection " - "layer size and D is the hidden size."); + "layer size and D is the hidden size." + " - ProjWeight = {W_rh}"); AddInput("Bias", - "(Tensor) the learnable weights, which contains two parts: " - "input-hidden bias weight and peephole connections weight if " - "setting `use_peepholes` True. " + "(Tensor) the learnable biases, which contains two parts: " + "input-hidden biases and peephole connections weights if " + "setting `use_peepholes` to `True`. " "1. `use_peepholes = False` " " - The shape is (1 x 4D). " " - Bias = {b_c, b_i, b_f, b_o}." @@ -155,27 +157,28 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); AddOutput("Projection", "(LoDTensor) the projection of the hidden state of LSTMP " - "operator. The shape is (T x P), and lod is the same with the " + "operator. The shape is (T x P), and LoD is the same with the " "`Input`."); AddOutput("Cell", "(LoDTensor) the cell state of LSTMP operator. " "The shape is (T x D), and lod is the same with the `Input`."); AddOutput("BatchGate", "(LoDTensor) This LoDTensor contains input gate, forget gate " - "and output gate after the nonlinear computation. This " - "LoDTensor has the same shape as the reorganized input, which " - "is also be called batch input. The LoD size is 2. The first " - "LoD is the batch offsets and the second LoD contains the " - "indexes, which denote the position of reorganized sequence " - "in the raw input.") + "and output gate after the activations. This LoDTensor has the " + "same shape as the reorganized input, which is also be called " + "batch input. The LoD size is 2. The first-level LoD is the " + "batch offsets and the second contains the indices, which " + "denotes the position of reorganized sequence in the raw input.") .AsIntermediate(); AddOutput("BatchCellPreAct", - "(LoDTensor) This LoDTensor is obtained in the forward and used " - "in the backward.") + "(LoDTensor) the pre-activation cell state reorganized in batch. " + "This LoDTensor is obtained in the forward and used in the " + "backward.") .AsIntermediate(); AddOutput("BatchHidden", - "(LoDTensor) This LoDTensor is obtained in the forward and used " - "in the backward.") + "(LoDTensor) the hidden state reorganized in batch. " + "This LoDTensor is obtained in the forward and used in the " + "backward.") .AsIntermediate(); AddOutput("OrderedP0", "(Tensor) the projection of the initial hidden state " @@ -190,12 +193,6 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, defalut: False) " "whether to compute reversed LSTMP.") .SetDefault(false); - AddAttr("share_cell_act", - "(bool, defalut: True) " - "whether to share activation with cell output. " - "If false, the projection would be linear, else " - "through an activation same with the cell output.") - .SetDefault(true); AddAttr( "gate_activation", "(string, default: sigmoid)" @@ -214,11 +211,21 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "`tanh` by default.") .SetDefault("tanh") .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("share_cell_act", + "(bool, defalut: True) " + "whether to share the activation of cell output with the " + "projection layer. When set to `False`, the projection " + "is simple linear, otherwise it will go through an " + "activation function same as `cell_activation`.") + .SetDefault(true); AddComment(R"DOC( -Long-Short Term Memory with Recurrent Projection (LSTMP) Operator. +Long-Short Term Memory with recurrent Projection layer (LSTMP) Operator. -LSTMP is stand LSTM appended by a recurrent projection layer to reduce the -number of parameters, espeacially when the output size is relative large. +LSTMP has a separate projection layer after the LSTM layer, projecting the +original hidden state to a lower-dimensional one, which is proposed to reduce +the number of total parameters and furthermore computational complexity for +the LSTM, espeacially for the case that the size of output units is relative +large (https://research.google.com/pubs/archive/43905.pdf). The formula is as follows: $$ @@ -226,13 +233,15 @@ i_t = \sigma(W_{ix}x_{t} + W_{ih}r_{t-1} + W_{ic}c_{t-1} + b_i) \\ f_t = \sigma(W_{fx}x_{t} + W_{fh}r_{t-1} + W_{fc}c_{t-1} + b_f) \\ -c_t = f_t \odot c_{t-1} + i_t \odot act_g(W_{cx}x_t + W_{ch}r_{t-1} + b_c) \\ +\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}r_{t-1} + b_c) \\ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\ +c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} + h_t = o_t \odot act_h(c_t) -r_t = act_{h'}(W_{rh}h_t) +r_t = \overline{act_h}(W_{rh}h_t) $$ where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix @@ -240,20 +249,23 @@ of weights from the input gate to the input), $W_{ic}, W_{fc}, W_{oc}$ are diagonal weight matrices for peephole connections. In our implementation, we use vectors to reprenset these diagonal weight matrices. The b terms denote bias vectors ($b_i$ is the input gate bias vector), $\sigma$ -is the non-line activations, such as logistic sigmoid function, and +is the activation, such as logistic sigmoid function, and $i, f, o$ and $c$ are the input gate, forget gate, output gate, and cell activation vectors, respectively, all of which have the same size as -the cell output activation vector $h$. $r$ denotes the recurrent projection -layer. +the cell output activation vector $h$. Here $h$ is usually called the hidden +state and $r$ denotes its recurrent projection. And $\tilde{c_t}$ is also +called the candidate hidden state, whose computation is based on the current +input and previous hidden state. The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$ are the cell input and cell output activation functions and `tanh` is usually -used for them. If `share_cell_act` setted to `False`, $act_h'$ will be linear -else will be same with $act_h$. +used for them. $\overline{act_h}$ is the activation function for the projection +layer. When `share_cell_act` set to `False`, $\overline{act_h}$ is an +identity activation, otherwise it will be same as $act_h$. Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$ operations on the input $x_{t}$ are NOT included in this operator. -Users can choose to use fully-connect operator before LSTMP operator. +Users can choose to use fully-connected operator before LSTMP operator. )DOC"); } From f3fe41078a047790b247be238468a9047e2bb691 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 24 Jan 2018 14:15:17 +0800 Subject: [PATCH 07/17] Fix conflicts and add more supported dtype. --- paddle/operators/multiplex_op.cc | 10 ++++++++-- paddle/operators/multiplex_op.cu | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index 78263da2fb..d275fa5cbb 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -119,7 +119,13 @@ REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp); REGISTER_OP_CPU_KERNEL( multiplex, - ops::MultiplexCPUKernel); + ops::MultiplexCPUKernel, + ops::MultiplexCPUKernel, + ops::MultiplexCPUKernel, + ops::MultiplexCPUKernel); REGISTER_OP_CPU_KERNEL( multiplex_grad, - ops::MultiplexGradCPUKernel); + ops::MultiplexGradCPUKernel, + ops::MultiplexGradCPUKernel, + ops::MultiplexGradCPUKernel, + ops::MultiplexGradCPUKernel); diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index 4372dc2c65..546e6e7a24 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -90,7 +90,13 @@ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( multiplex, - ops::MultiplexGPUKernel); + ops::MultiplexGPUKernel, + ops::MultiplexGPUKernel, + ops::MultiplexGPUKernel, + ops::MultiplexGPUKernel); REGISTER_OP_CUDA_KERNEL( multiplex_grad, - ops::MultiplexGradGPUKernel); + ops::MultiplexGradGPUKernel, + ops::MultiplexGradGPUKernel, + ops::MultiplexGradGPUKernel, + ops::MultiplexGradGPUKernel); From 76beff86a0f8e0d6856691b2968bafa52bf3a859 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 24 Jan 2018 01:34:54 -0800 Subject: [PATCH 08/17] Make the projection activation configurable --- paddle/operators/lstmp_op.cc | 76 +++++++++---------- paddle/operators/lstmp_op.h | 14 ++-- python/paddle/v2/fluid/tests/test_lstmp_op.py | 41 +++++----- 3 files changed, 66 insertions(+), 65 deletions(-) diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc index 85be64f44c..14469c708d 100644 --- a/paddle/operators/lstmp_op.cc +++ b/paddle/operators/lstmp_op.cc @@ -23,27 +23,29 @@ class LSTMPOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTMP should not be null."); + "Input(Input) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Weight"), - "Input(Weight) of LSTMP should not be null."); + "Input(Weight) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("ProjWeight"), - "Input(ProjWeight) of LSTMP should not be null."); + "Input(ProjWeight) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Bias"), - "Input(Bias) of LSTMP should not be null."); + "Input(Bias) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Projection"), - "Output(Projection) of LSTMP should not be null."); + "Output(Projection) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Output(Cell) of LSTMP should not be null."); + "Output(Cell) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), - "Output(BatchGate) of LSTMP should not be null."); + "Output(BatchGate) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), - "Output(BatchGate) of LSTMP should not be null."); + "Output(BatchCellPreAct) of LSTMP operator should not be " + "null."); PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"), - "Output(BatchHidden) of LSTMP should not be null."); + "Output(BatchHidden) of LSTMP operator should not be null."); auto in_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE_EQ(in_dims.size(), 2, + "Input(X)'s rank of LSTMP operator must be 2."); int frame_size = in_dims[1] / 4; auto w_dims = ctx->GetInputDim("Weight"); @@ -68,8 +70,8 @@ class LSTMPOp : public framework::OperatorWithKernel { if (ctx->HasInput("H0")) { PADDLE_ENFORCE(ctx->HasInput("C0"), - "Input(C0) and Input(H0) of LSTMP should not " - "be null at the same time."); + "Input(C0) of LSTMP operator should not be null after " + "Input(H0) provided."); auto h_dims = ctx->GetInputDim("H0"); auto c_dims = ctx->GetInputDim("C0"); PADDLE_ENFORCE(h_dims == c_dims, @@ -132,8 +134,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("C0", "(Tensor, optional) the initial cell state is an optional " "input. This is a tensor with shape (N x D), where N is the " - "batch size. Only one of `H0` and `C0` can be NULL at the same " - "time.") + "batch size. `C0` should not be null if `H0` provided.") .AsDispensable(); AddInput("Weight", "(Tensor) the learnable hidden-hidden weights." @@ -211,13 +212,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "`tanh` by default.") .SetDefault("tanh") .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("share_cell_act", - "(bool, defalut: True) " - "whether to share the activation of cell output with the " - "projection layer. When set to `False`, the projection " - "is simple linear, otherwise it will go through an " - "activation function same as `cell_activation`.") - .SetDefault(true); + AddAttr("proj_activation", + "(string, default: tanh)" + "The activation for projection output, " + "`tanh` by defalut.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); AddComment(R"DOC( Long-Short Term Memory with recurrent Projection layer (LSTMP) Operator. @@ -226,20 +226,21 @@ original hidden state to a lower-dimensional one, which is proposed to reduce the number of total parameters and furthermore computational complexity for the LSTM, espeacially for the case that the size of output units is relative large (https://research.google.com/pubs/archive/43905.pdf). + The formula is as follows: $$ -i_t = \sigma(W_{ix}x_{t} + W_{ih}r_{t-1} + W_{ic}c_{t-1} + b_i) \\ +i_t = \sigma(W_{ix}x_{t} + W_{ir}r_{t-1} + W_{ic}c_{t-1} + b_i) \\ -f_t = \sigma(W_{fx}x_{t} + W_{fh}r_{t-1} + W_{fc}c_{t-1} + b_f) \\ +f_t = \sigma(W_{fx}x_{t} + W_{fr}r_{t-1} + W_{fc}c_{t-1} + b_f) \\ -\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}r_{t-1} + b_c) \\ +\tilde{c_t} = act_g(W_{cx}x_t + W_{cr}r_{t-1} + b_c) \\ -o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\ +o_t = \sigma(W_{ox}x_{t} + W_{or}r_{t-1} + W_{oc}c_t + b_o) \\ -c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} +c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} \\ -h_t = o_t \odot act_h(c_t) +h_t = o_t \odot act_h(c_t) \\ r_t = \overline{act_h}(W_{rh}h_t) $$ @@ -259,9 +260,8 @@ input and previous hidden state. The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$ are the cell input and cell output activation functions and `tanh` is usually -used for them. $\overline{act_h}$ is the activation function for the projection -layer. When `share_cell_act` set to `False`, $\overline{act_h}$ is an -identity activation, otherwise it will be same as $act_h$. +used for them. $\overline{act_h}$ is the activation function for the +projection output, usually using `identity` or same as $act_h$. Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$ operations on the input $x_{t}$ are NOT included in this operator. @@ -277,22 +277,22 @@ class LSTMPGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTMP should not be null."); + "Input(Input) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Projection"), - "Input(Projection) of LSTMP should not be null."); + "Input(Projection) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Cell"), - "Input(Cell) of LSTMP should not be null."); + "Input(Cell) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Weight"), - "Input(Weight) of LSTMP should not be null."); + "Input(Weight) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("ProjWeight"), - "Input(ProjWeight) of LSTMP should not be null."); + "Input(ProjWeight) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Bias"), - "Input(Bias) of LSTMP should not be null."); + "Input(Bias) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("BatchGate"), - "Input(BatchGate) of LSTMP should not be null."); + "Input(BatchGate) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"), - "Input(BatchGate) of LSTMP should not be null."); + "Input(BatchGate) of LSTMP operator should not be null."); auto SetOutGradDim = [&ctx](const std::string& name) { auto g_name = framework::GradVarName(name); diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h index 0048f7e1c6..9dc37615f0 100644 --- a/paddle/operators/lstmp_op.h +++ b/paddle/operators/lstmp_op.h @@ -136,7 +136,8 @@ class LSTMPKernel : public framework::OpKernel { ctx.Attr("cell_activation")); auto cand_act = math::detail::GetActivationType( ctx.Attr("candidate_activation")); - auto share_cell_act = ctx.Attr("share_cell_act"); + auto proj_act = math::detail::GetActivationType( + ctx.Attr("proj_activation")); auto& place = *ctx.template device_context().eigen_device(); for (size_t n = 0; n < num_batch; n++) { @@ -174,7 +175,7 @@ class LSTMPKernel : public framework::OpKernel { math::matmul(device_ctx, ordered_h0, false, *proj_weight, false, static_cast(1.0), ordered_proj0, static_cast(0.0)); - if (share_cell_act) { + if (proj_act != math::detail::ActivationType::kIdentity) { auto proj0_dev = EigenMatrix::From(*ordered_proj0); ActCompute(cell_act, place, proj0_dev, proj0_dev); } @@ -194,7 +195,7 @@ class LSTMPKernel : public framework::OpKernel { math::matmul(device_ctx, hidden_t, false, *proj_weight, false, static_cast(1.0), &proj_t, static_cast(0.0)); - if (share_cell_act) { + if (proj_act != math::detail::ActivationType::kIdentity) { auto proj_t_dev = EigenMatrix::From(proj_t); ActCompute(cell_act, place, proj_t_dev, proj_t_dev); } @@ -348,7 +349,8 @@ class LSTMPGradKernel : public framework::OpKernel { ctx.Attr("cell_activation")); auto cand_act = math::detail::GetActivationType( ctx.Attr("candidate_activation")); - auto share_cell_act = ctx.Attr("share_cell_act"); + auto proj_act = math::detail::GetActivationType( + ctx.Attr("proj_activation")); auto& place = *ctx.template device_context().eigen_device(); auto batch_starts = batch_gate->lod()[0]; @@ -359,7 +361,7 @@ class LSTMPGradKernel : public framework::OpKernel { Tensor cur_proj = batch_proj.Slice(bstart, bend); Tensor proj_g = batch_proj_g.Slice(bstart, bend); - if (share_cell_act) { + if (proj_act != math::detail::ActivationType::kIdentity) { auto cur_proj_dev = EigenMatrix::From(cur_proj); auto proj_g_dev = EigenMatrix::From(proj_g); ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev, @@ -439,7 +441,7 @@ class LSTMPGradKernel : public framework::OpKernel { math::matmul(device_ctx, gate_g, false, *weight, true, static_cast(1.0), &proj0_g, static_cast(0.0)); - if (share_cell_act) { + if (proj_act != math::detail::ActivationType::kIdentity) { auto proj0_dev = EigenMatrix::From(*ordered_proj0); auto proj0_g_dev = EigenMatrix::From(proj0_g); ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev, diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py index 8835cae504..08fc32e117 100644 --- a/python/paddle/v2/fluid/tests/test_lstmp_op.py +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -41,7 +41,7 @@ def relu(x): return np.maximum(x, 0) -ACTVATION = { +ACTIVATION = { 'identity': identity, 'sigmoid': sigmoid, 'tanh': tanh, @@ -63,8 +63,9 @@ def lstmp( act_gate=None, act_cell=None, act_cand=None, - share_cell_act=True): - def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand): + act_proj=None): + def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand, + act_proj): g = np.dot(r_pre, w_r) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) @@ -86,8 +87,7 @@ def lstmp( h = g_o * act_cell(c) # projection r = np.dot(h, w_rh) - if share_cell_act: - r = act_cell(r) + r = act_proj(r) return r, c def _reverse(x, lod): @@ -110,13 +110,12 @@ def lstmp( seq_len = offset[i + 1] - offset[i] x = input[offset[i]:offset[i + 1], :] r_pre = np.dot(h0[i], w_rh) # 1 x P - if share_cell_act: - r_pre = act_cell(r_pre) + r_pre = act_proj(r_pre) c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, act_gate, - act_cell, act_cand) + act_cell, act_cand, act_proj) projection.append(r_pre.flatten()) cell.append(c_pre.flatten()) @@ -131,7 +130,7 @@ def lstmp( return projection, cell -class TestLstmOp(OpTest): +class TestLstmpOp(OpTest): def set_argument(self): self.lod = [[0, 2, 5, 7]] # hidden size @@ -142,8 +141,8 @@ class TestLstmOp(OpTest): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = self.act_cell - self.share_cell_act = True self.has_initial_state = False self.is_reverse = False self.use_peepholes = True @@ -172,8 +171,8 @@ class TestLstmOp(OpTest): w_c = b[:, 4 * self.D:] if self.use_peepholes else None w_rh = np.random.normal(size=(self.D, self.P)).astype('float64') r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse, - ACTVATION[self.act_gate], ACTVATION[self.act_cell], - ACTVATION[self.act_cand], self.share_cell_act) + ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], + ACTIVATION[self.act_cand], ACTIVATION[self.act_proj]) self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh} @@ -193,7 +192,7 @@ class TestLstmOp(OpTest): 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, 'candidate_activation': self.act_cand, - 'share_cell_act': self.share_cell_act + 'proj_activation': self.act_proj } def test_check_output(self): @@ -212,7 +211,7 @@ class TestLstmOp(OpTest): max_relative_error=1e-2) -class TestLstmOpHasInitial(TestLstmOp): +class TestLstmpOpHasInitial(TestLstmpOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 @@ -221,8 +220,8 @@ class TestLstmOpHasInitial(TestLstmOp): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = self.act_cell - self.share_cell_act = True self.has_initial_state = True self.is_reverse = True self.use_peepholes = True @@ -313,7 +312,7 @@ class TestLstmOpHasInitial(TestLstmOp): no_grad_set=set('C0')) -class TestLstmOpRerverse(TestLstmOp): +class TestLstmpOpRerverse(TestLstmpOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 @@ -322,14 +321,14 @@ class TestLstmOpRerverse(TestLstmOp): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = self.act_cell - self.share_cell_act = True self.has_initial_state = False self.is_reverse = True self.use_peepholes = True -class TestLstmOpNotUsePeepholes(TestLstmOp): +class TestLstmpOpNotUsePeepholes(TestLstmpOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 @@ -338,14 +337,14 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = self.act_cell - self.share_cell_act = True self.has_initial_state = False self.is_reverse = False self.use_peepholes = False -class TestLstmOpNotShareCellAct(TestLstmOp): +class TestLstmpOpLinearProjection(TestLstmpOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] self.D = 16 @@ -354,8 +353,8 @@ class TestLstmOpNotShareCellAct(TestLstmOp): self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' + self.act_proj = 'identity' - self.share_cell_act = False self.has_initial_state = False self.is_reverse = False self.use_peepholes = True From a249c0cae97e64c89c4db480d0092c19a44d3dfd Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 24 Jan 2018 19:51:37 +0800 Subject: [PATCH 09/17] Refine doc and fix dtype. --- python/paddle/v2/fluid/layers/nn.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 477ae7cea9..d87cedca89 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -2719,22 +2719,22 @@ def multiplex(inputs, index): input variables and let :math:`I_i` represents the i-th input variable and i is in [0, :math:`m`). All input variables are tensors with same shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. Please note that rank of the - input tensor should be at least 2. Each input variable will be viewed as a + input tensor should be at least 2. Each input variable will be treated as a 2-D matrix with shape [:math:`M`, :math:`N`] where :math:`M` for :math:`d_0` and :math:`N` for :math:`d_1` * :math:`d_2` * ... * :math:`d_R`. Let :math:`I_i[j]` be the j-th row of the i-th input variable. The given index variable should be a 2-D tensor with shape [:math:`M`, 1]. Let `ID[i]` be - the i-th index value of index variable. Then the output variable will be a - tensor with shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. If we view - the output tensor as a 2-D matrix with shape [:math:`M`, :math:`N`] and let - :math:`O[i]` be the i-th row of the matrix, then values of `O[i]` come from + the i-th index value of the index variable. Then the output variable will + be a tensor with shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. If we + treat the output tensor as a 2-D matrix with shape [:math:`M`, :math:`N`] + and let :math:`O[i]` be the i-th row of the matrix, then `O[i]` is equal to :math:`I_{ID[i]}[i]`. Args: - inputs (list): Input variables which are tensors with same shape and the - rank is at least 2. + inputs (list): A list of variables to gather from. All variables have the + same shape and the rank is at least 2. index (Variable): Tensor, index variable which is a 2-D tensor - with shape [M, 1] where M for batch size. + with shape [M, 1] where M is the batch size. Returns: Variable: Multiplex variable gathered from input variables. @@ -2748,7 +2748,12 @@ def multiplex(inputs, index): out = fluid.layers.multiplex(inputs=[x1, x2], index=index) """ helper = LayerHelper('multiplex', **locals()) - out = helper.create_tmp_variable(helper.input_dtype()) + + if not isinstance(inputs, list) and len(inputs) < 2: + raise ValueError("inputs should be a list object and contains at least " + "2 elements.") + + out = helper.create_tmp_variable(inputs[0].dtype) helper.append_op( type='multiplex', inputs={'X': inputs, From 9ecc54a11b61e09c3c503b51049394e61f8e1fa3 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 24 Jan 2018 09:00:03 -0800 Subject: [PATCH 10/17] Remove redundant code in unit test --- paddle/operators/lstmp_op.cc | 2 +- paddle/operators/lstmp_op.h | 2 +- python/paddle/v2/fluid/tests/test_lstmp_op.py | 60 +++---------------- 3 files changed, 11 insertions(+), 53 deletions(-) diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc index 14469c708d..c96b30ba35 100644 --- a/paddle/operators/lstmp_op.cc +++ b/paddle/operators/lstmp_op.cc @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h index 9dc37615f0..ee82d5c10a 100644 --- a/paddle/operators/lstmp_op.h +++ b/paddle/operators/lstmp_op.h @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py index 08fc32e117..f137fc61b3 100644 --- a/python/paddle/v2/fluid/tests/test_lstmp_op.py +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -131,7 +131,10 @@ def lstmp( class TestLstmpOp(OpTest): - def set_argument(self): + def reset_argument(self): + pass + + def setUp(self): self.lod = [[0, 2, 5, 7]] # hidden size self.D = 16 @@ -147,8 +150,7 @@ class TestLstmpOp(OpTest): self.is_reverse = False self.use_peepholes = True - def setUp(self): - self.set_argument() + self.reset_argument() self.op_type = 'lstmp' T = self.lod[0][-1] @@ -212,19 +214,8 @@ class TestLstmpOp(OpTest): class TestLstmpOpHasInitial(TestLstmpOp): - def set_argument(self): - self.lod = [[0, 2, 5, 7]] - self.D = 16 - self.P = 5 - - self.act_gate = 'sigmoid' - self.act_cell = 'tanh' - self.act_cand = 'tanh' - self.act_proj = self.act_cell - + def reset_argument(self): self.has_initial_state = True - self.is_reverse = True - self.use_peepholes = True def test_check_grad(self): # TODO(qingqing) remove folowing lines after the check_grad is refined. @@ -313,52 +304,19 @@ class TestLstmpOpHasInitial(TestLstmpOp): class TestLstmpOpRerverse(TestLstmpOp): - def set_argument(self): - self.lod = [[0, 2, 5, 7]] - self.D = 16 - self.P = 10 - - self.act_gate = 'sigmoid' - self.act_cell = 'tanh' - self.act_cand = 'tanh' - self.act_proj = self.act_cell - - self.has_initial_state = False + def reset_argument(self): self.is_reverse = True - self.use_peepholes = True class TestLstmpOpNotUsePeepholes(TestLstmpOp): - def set_argument(self): - self.lod = [[0, 2, 5, 7]] - self.D = 16 - self.P = 10 - - self.act_gate = 'sigmoid' - self.act_cell = 'tanh' - self.act_cand = 'tanh' - self.act_proj = self.act_cell - - self.has_initial_state = False - self.is_reverse = False + def reset_argument(self): self.use_peepholes = False class TestLstmpOpLinearProjection(TestLstmpOp): - def set_argument(self): - self.lod = [[0, 2, 5, 7]] - self.D = 16 - self.P = 10 - - self.act_gate = 'sigmoid' - self.act_cell = 'tanh' - self.act_cand = 'tanh' + def reset_argument(self): self.act_proj = 'identity' - self.has_initial_state = False - self.is_reverse = False - self.use_peepholes = True - if __name__ == '__main__': unittest.main() From 5c6fc3f92ff05edbc77284a1ec34666eac34646e Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 24 Jan 2018 19:03:12 -0800 Subject: [PATCH 11/17] Make TestLstmpOp inherit from TestLstmOp --- python/paddle/v2/fluid/tests/test_lstm_op.py | 6 +-- python/paddle/v2/fluid/tests/test_lstmp_op.py | 52 +++---------------- 2 files changed, 11 insertions(+), 47 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_lstm_op.py b/python/paddle/v2/fluid/tests/test_lstm_op.py index d9fa01e247..3e79f9d8e1 100644 --- a/python/paddle/v2/fluid/tests/test_lstm_op.py +++ b/python/paddle/v2/fluid/tests/test_lstm_op.py @@ -42,7 +42,7 @@ def relu(x): return np.maximum(x, 0) -ACTVATION = { +ACTIVATION = { 'identity': identity, 'sigmoid': sigmoid, 'tanh': tanh, @@ -158,8 +158,8 @@ class TestLstmOp(OpTest): w_b = b[:, 0:4 * self.D] w_c = b[:, 4 * self.D:] if self.use_peepholes else None h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, - ACTVATION[self.act_gate], ACTVATION[self.act_cell], - ACTVATION[self.act_cand]) + ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], + ACTIVATION[self.act_cand]) self.inputs = {'Input': (x, self.lod), 'Weight': w} diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py index f137fc61b3..92a954a9aa 100644 --- a/python/paddle/v2/fluid/tests/test_lstmp_op.py +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -13,39 +13,13 @@ #limitations under the License. import unittest import numpy as np -from op_test import OpTest - -SIGMOID_THRESHOLD_MIN = -40.0 -SIGMOID_THRESHOLD_MAX = 13.0 -EXP_MAX_INPUT = 40.0 - - -def identity(x): - return x - - -def sigmoid(x): - y = np.copy(x) - y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN - y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX - return 1. / (1. + np.exp(-y)) - - -def tanh(x): - y = -2. * x - y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT - return (2. / (1. + np.exp(y))) - 1. - - -def relu(x): - return np.maximum(x, 0) - +import test_lstm_op as LstmTest ACTIVATION = { - 'identity': identity, - 'sigmoid': sigmoid, - 'tanh': tanh, - 'relu': relu + 'identity': LstmTest.identity, + 'sigmoid': LstmTest.sigmoid, + 'tanh': LstmTest.tanh, + 'relu': LstmTest.relu } @@ -55,7 +29,7 @@ def lstmp( lod, # 1 x N h0=None, # N x D c0=None, # N x D - w_r=None, # P x 5D + w_r=None, # P x 4D w_rh=None, # D x P w_b=None, # 1 x 4D w_c=None, # 1 x 3D @@ -130,26 +104,16 @@ def lstmp( return projection, cell -class TestLstmpOp(OpTest): +class TestLstmpOp(LstmTest.TestLstmOp): def reset_argument(self): pass def setUp(self): - self.lod = [[0, 2, 5, 7]] - # hidden size - self.D = 16 + self.set_argument() # projection size self.P = 10 - - self.act_gate = 'sigmoid' - self.act_cell = 'tanh' - self.act_cand = 'tanh' self.act_proj = self.act_cell - self.has_initial_state = False - self.is_reverse = False - self.use_peepholes = True - self.reset_argument() self.op_type = 'lstmp' From 7eb19abc76d58ffc2a4968732a08545b3f8cecb5 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 25 Jan 2018 13:00:22 +0800 Subject: [PATCH 12/17] Refine the doc. --- python/paddle/v2/fluid/layers/nn.py | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index d87cedca89..bae33f6a15 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -2714,21 +2714,21 @@ def multiplex(inputs, index): """ **Multiplex Layer** - Referring to the given index variable, this layer gathers from the input - variables to output a multiplex variable. Assuming that there are :math:`m` - input variables and let :math:`I_i` represents the i-th input variable and i - is in [0, :math:`m`). All input variables are tensors with same shape - [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. Please note that rank of the - input tensor should be at least 2. Each input variable will be treated as a - 2-D matrix with shape [:math:`M`, :math:`N`] where :math:`M` for :math:`d_0` - and :math:`N` for :math:`d_1` * :math:`d_2` * ... * :math:`d_R`. Let - :math:`I_i[j]` be the j-th row of the i-th input variable. The given index - variable should be a 2-D tensor with shape [:math:`M`, 1]. Let `ID[i]` be - the i-th index value of the index variable. Then the output variable will - be a tensor with shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. If we - treat the output tensor as a 2-D matrix with shape [:math:`M`, :math:`N`] - and let :math:`O[i]` be the i-th row of the matrix, then `O[i]` is equal to - :math:`I_{ID[i]}[i]`. + Referring to the given index variable, this layer selects rows from the + input variables to construct a multiplex variable. Assuming that there are + :math:`m` input variables and :math:`I_i` represents the i-th input + variable and :math:`i` is in [0, :math:`m`). All input variables are + tensors with same shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`]. + Please note that rank of the input tensor should be at least 2. Each input + variable will be treated as a 2-D matrix with shape [:math:`M`, :math:`N`] + where :math:`M` for :math:`d_0` and :math:`N` for :math:`d_1` * :math:`d_2` + * ... * :math:`d_R`. Let :math:`I_i[j]` be the j-th row of the i-th input + variable. The given index variable should be a 2-D tensor with shape + [:math:`M`, 1]. Let `ID[i]` be the i-th index value of the index variable. + Then the output variable will be a tensor with shape [:math:`d_0`, + :math:`d_1`, ..., :math:`d_R`]. If we treat the output tensor as a 2-D + matrix with shape [:math:`M`, :math:`N`] and let :math:`O[i]` be the i-th + row of the matrix, then `O[i]` is equal to :math:`I_{ID[i]}[i]`. Args: inputs (list): A list of variables to gather from. All variables have the From f9a11f51ce4b815c8c456d5d699c675726abdb91 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 25 Jan 2018 14:32:53 +0800 Subject: [PATCH 13/17] Add noavx-openblas whl package download link --- doc/getstarted/build_and_install/pip_install_cn.rst | 1 + doc/getstarted/build_and_install/pip_install_en.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/doc/getstarted/build_and_install/pip_install_cn.rst b/doc/getstarted/build_and_install/pip_install_cn.rst index 0c741e936b..8e4165da6b 100644 --- a/doc/getstarted/build_and_install/pip_install_cn.rst +++ b/doc/getstarted/build_and_install/pip_install_cn.rst @@ -39,6 +39,7 @@ PaddlePaddle可以使用常用的Python包管理工具 "cpu_avx_mkl", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cpu_avx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "暂无" + "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "暂无" "cuda7.5_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda8.0_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda8.0_cudnn7_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" diff --git a/doc/getstarted/build_and_install/pip_install_en.rst b/doc/getstarted/build_and_install/pip_install_en.rst index 285ed09805..c1e806c0fe 100644 --- a/doc/getstarted/build_and_install/pip_install_en.rst +++ b/doc/getstarted/build_and_install/pip_install_en.rst @@ -42,6 +42,7 @@ If the links below shows up the login form, just click "Log in as guest" to star "cpu_avx_mkl", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cpu_avx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "Not Available" + "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "Not Available" "cuda7.5_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda8.0_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda8.0_cudnn7_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" From 900f411d629dd4fe18417455055529b86f4455f2 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 25 Jan 2018 17:11:52 +0800 Subject: [PATCH 14/17] fix dist transpiler bug --- python/paddle/v2/fluid/distribute_transpiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 934eba73b8..908810c8be 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -225,7 +225,7 @@ class DistributeTranspiler: if len(splited_vars) <= 1: continue orig_var = program.global_block().vars[varname] - if orig_var == core.VarDesc.VarType.SELECTED_ROWS: + if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: height_sections = [] for v in splited_vars: height_sections.append(v.shape[0]) @@ -234,7 +234,7 @@ class DistributeTranspiler: inputs={"X": orig_var}, outputs={"Out": splited_vars}, attrs={"height_sections": height_sections}) - elif orig_var == core.VarDesc.VarType.LOD_TENSOR: + elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: sections = [] for v in splited_vars: sections.append(v.shape[0]) From 7a2e6dead9a8d13ade81e699932b5e78cb6ea64b Mon Sep 17 00:00:00 2001 From: peterzhang2029 Date: Thu, 25 Jan 2018 18:59:59 +0800 Subject: [PATCH 15/17] fix test_rnn_encoder_decoder --- ...encoder_context.py => test_rnn_encoder_decoder.py} | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) rename python/paddle/v2/fluid/tests/book/{test_machine_translation_encoder_context.py => test_rnn_encoder_decoder.py} (95%) diff --git a/python/paddle/v2/fluid/tests/book/test_machine_translation_encoder_context.py b/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py similarity index 95% rename from python/paddle/v2/fluid/tests/book/test_machine_translation_encoder_context.py rename to python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py index 53ed912c6f..3fd3dbaf77 100644 --- a/python/paddle/v2/fluid/tests/book/test_machine_translation_encoder_context.py +++ b/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py @@ -118,12 +118,13 @@ def seq_to_seq_net(): src_forward, src_backward = bi_lstm_encoder( input_seq=src_embedding, hidden_size=encoder_size) - encoded_vector = fluid.layers.concat( - input=[src_forward, src_backward], axis=1) + src_forward_last = fluid.layers.sequence_last_step(input=src_forward) + src_backward_first = fluid.layers.sequence_first_step(input=src_backward) - enc_vec_last = fluid.layers.sequence_last_step(input=encoded_vector) + encoded_vector = fluid.layers.concat( + input=[src_forward_last, src_backward_first], axis=1) - decoder_boot = fluid.layers.fc(input=enc_vec_last, + decoder_boot = fluid.layers.fc(input=encoded_vector, size=decoder_size, bias_attr=False, act='tanh') @@ -137,7 +138,7 @@ def seq_to_seq_net(): dtype='float32') prediction = lstm_decoder_without_attention(trg_embedding, decoder_boot, - enc_vec_last, decoder_size) + encoded_vector, decoder_size) label = fluid.layers.data( name='label_sequence', shape=[1], dtype='int64', lod_level=1) cost = fluid.layers.cross_entropy(input=prediction, label=label) From 7333df8510953a0c99a1c0702666ff090c005378 Mon Sep 17 00:00:00 2001 From: QI JUN Date: Thu, 25 Jan 2018 21:16:45 +0800 Subject: [PATCH 16/17] fix pool_op bug (#7879) --- paddle/operators/pool_op.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h index c3d82ecbde..d6ba5e298a 100644 --- a/paddle/operators/pool_op.h +++ b/paddle/operators/pool_op.h @@ -139,10 +139,8 @@ class PoolGradKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); - auto temp = framework::EigenVector::Flatten(*in_x_grad); - temp.device( - *context.template device_context().eigen_device()) = - temp.constant(static_cast(0)); + paddle::operators::math::SetConstant set_constant; + set_constant(dev_ctx, in_x_grad, 0.0); switch (ksize.size()) { case 2: { From 788f5c6d439f2795d9882697be1e257eedfc5c5a Mon Sep 17 00:00:00 2001 From: kexinzhao Date: Thu, 25 Jan 2018 18:20:12 -0800 Subject: [PATCH 17/17] New Run() method for framework::Executor (#7807) * initial commit * add new executor run function * fix bug * fix multiple definition of feed_fetch_method issue * fix cmake * fix tensor copy error * refine executor code * add comments * temporary modification * address comments * fix bug --- paddle/framework/CMakeLists.txt | 4 +- paddle/framework/executor.cc | 164 ++++++++++++++++++++++++++ paddle/framework/executor.h | 6 + paddle/framework/feed_fetch_method.cc | 56 +++++++++ paddle/framework/feed_fetch_method.h | 34 +----- paddle/inference/inference.cc | 14 ++- paddle/pybind/CMakeLists.txt | 2 +- paddle/pybind/pybind.cc | 4 +- 8 files changed, 244 insertions(+), 40 deletions(-) create mode 100644 paddle/framework/feed_fetch_method.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 8d9260811a..2804969842 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -74,8 +74,10 @@ cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op) cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) +cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope -framework_proto backward glog lod_rank_table profiler) +framework_proto backward glog lod_rank_table profiler feed_fetch_method) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index c28ffefdd0..50a70d723e 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "gflags/gflags.h" +#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/feed_fetch_type.h" #include "paddle/framework/lod_rank_table.h" #include "paddle/framework/lod_tensor_array.h" @@ -149,5 +150,168 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, } } +// Check whether the block already has feed operators and feed_holder. +// Return false if the block does not have any feed operators. +// If some feed operators have been prepended to the block, check that +// the info contained in these feed operators matches the feed_targets +// and feed_holder_name. Raise exception when any mismatch is found. +// Return true if the block has feed operators and holder of matching info. +static bool has_feed_operators( + BlockDesc* block, std::map& feed_targets, + const std::string& feed_holder_name) { + size_t feed_count = 0; + for (auto* op : block->AllOps()) { + if (op->Type() == kFeedOpType) { + feed_count++; + PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, + "Input to feed op should be '%s'", feed_holder_name); + std::string feed_target_name = op->Output("Out")[0]; + PADDLE_ENFORCE( + feed_targets.find(feed_target_name) != feed_targets.end(), + "Feed operator output name '%s' cannot be found in 'feed_targets'", + feed_target_name); + } else { + break; + } + } + + if (feed_count > 0) { + PADDLE_ENFORCE_EQ( + feed_count, feed_targets.size(), + "The number of feed operators should match 'feed_targets'"); + + // When feed operator are present, so should be feed_holder + auto var = block->FindVar(feed_holder_name); + PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", + feed_holder_name); + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH, + "'%s' variable should be 'FEED_MINIBATCH' type", + feed_holder_name); + } + + return feed_count > 0; +} + +// Check whether the block already has fetch operators and fetch_holder. +// Return false if the block does not have any fetch operators. +// If some fetch operators have been appended to the block, check that +// the info contained in these fetch operators matches the fetch_targets +// and fetch_holder_name. Raise exception when any mismatch is found. +// Return true if the block has fetch operators and holder of matching info. +static bool has_fetch_operators( + BlockDesc* block, std::map& fetch_targets, + const std::string& fetch_holder_name) { + size_t fetch_count = 0; + for (auto* op : block->AllOps()) { + if (op->Type() == kFetchOpType) { + fetch_count++; + PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, + "Output of fetch op should be '%s'", fetch_holder_name); + std::string fetch_target_name = op->Input("X")[0]; + PADDLE_ENFORCE( + fetch_targets.find(fetch_target_name) != fetch_targets.end(), + "Fetch operator input name '%s' cannot be found in 'fetch_targets'", + fetch_target_name); + } + } + + if (fetch_count > 0) { + PADDLE_ENFORCE_EQ( + fetch_count, fetch_targets.size(), + "The number of fetch operators should match 'fetch_targets'"); + + // When fetch operator are present, so should be fetch_holder + auto var = block->FindVar(fetch_holder_name); + PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", + fetch_holder_name); + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST, + "'%s' variable should be 'FETCH_LIST' type", + fetch_holder_name); + } + + return fetch_count > 0; +} + +void Executor::Run(const ProgramDesc& program, Scope* scope, + std::map& feed_targets, + std::map& fetch_targets, + const std::string& feed_holder_name, + const std::string& fetch_holder_name) { + auto* copy_program = new ProgramDesc(program); + auto* global_block = copy_program->MutableBlock(0); + + if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { + // create feed_holder variable + auto* feed_holder = global_block->Var(feed_holder_name); + feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH); + feed_holder->SetPersistable(true); + + int i = 0; + for (auto& feed_target : feed_targets) { + std::string var_name = feed_target.first; + VLOG(3) << "feed target's name: " << var_name; + + // prepend feed op + auto* op = global_block->PrependOp(); + op->SetType(kFeedOpType); + op->SetInput("X", {feed_holder_name}); + op->SetOutput("Out", {var_name}); + op->SetAttr("col", {static_cast(i)}); + op->CheckAttrs(); + + i++; + } + } + + // map the data of feed_targets to feed_holder + for (auto* op : global_block->AllOps()) { + if (op->Type() == kFeedOpType) { + std::string feed_target_name = op->Output("Out")[0]; + int idx = boost::get(op->GetAttr("col")); + SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, + idx); + } else { + break; + } + } + + if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { + // create fetch_holder variable + auto* fetch_holder = global_block->Var(fetch_holder_name); + fetch_holder->SetType(proto::VarDesc::FETCH_LIST); + fetch_holder->SetPersistable(true); + + int i = 0; + for (auto& fetch_target : fetch_targets) { + std::string var_name = fetch_target.first; + VLOG(3) << "fetch target's name: " << var_name; + + // append fetch op + auto* op = global_block->AppendOp(); + op->SetType(kFetchOpType); + op->SetInput("X", {var_name}); + op->SetOutput("Out", {fetch_holder_name}); + op->SetAttr("col", {static_cast(i)}); + op->CheckAttrs(); + + i++; + } + } + + Run(*copy_program, scope, 0, true, true); + + // obtain the data of fetch_targets from fetch_holder + for (auto* op : global_block->AllOps()) { + if (op->Type() == kFetchOpType) { + std::string fetch_target_name = op->Input("X")[0]; + int idx = boost::get(op->GetAttr("col")); + *fetch_targets[fetch_target_name] = + GetFetchVariable(*scope, fetch_holder_name, idx); + } + } + + delete copy_program; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h index d869e18901..035ff48a52 100644 --- a/paddle/framework/executor.h +++ b/paddle/framework/executor.h @@ -41,6 +41,12 @@ class Executor { void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true, bool create_vars = true); + void Run(const ProgramDesc& program, Scope* scope, + std::map& feed_targets, + std::map& fetch_targets, + const std::string& feed_holder_name = "feed", + const std::string& fetch_holder_name = "fetch"); + private: const platform::Place place_; }; diff --git a/paddle/framework/feed_fetch_method.cc b/paddle/framework/feed_fetch_method.cc new file mode 100644 index 0000000000..21201b6755 --- /dev/null +++ b/paddle/framework/feed_fetch_method.cc @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/feed_fetch_method.h" +#include "glog/logging.h" +#include "paddle/framework/variable.h" + +namespace paddle { +namespace framework { + +void SetFeedVariable(Scope* scope, const LoDTensor& input, + const std::string& var_name, size_t index) { + // If var_name Variable is not found in GlobalScope, a new variable will + // be created. + VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; + Variable* g_feed_value = scope->Var(var_name); + auto& feed_inputs = + *(g_feed_value->GetMutable>()); + if (index >= feed_inputs.size()) { + feed_inputs.resize(index + 1); + } + // shared data with input tensor + feed_inputs[index].ShareDataWith(input); + // set lod + feed_inputs[index].set_lod(input.lod()); +} + +LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, + size_t index) { + // Since we want to fetch LodTensor from a variable, the variable must + // be created alreadly. + Variable* g_fetch_value = scope.FindVar(var_name); + PADDLE_ENFORCE(g_fetch_value->IsType(), + "Only %s can be invoked by GetFetchVariable", + typeid(FeedFetchList).name()); + auto& fetch_outputs = *g_fetch_value->GetMutable(); + auto& tensor = fetch_outputs[index]; + VLOG(3) << "Fetch " << var_name << " with index " << index + << " shape= " << tensor.dims(); + PADDLE_ENFORCE_LT(index, fetch_outputs.size()); + return tensor; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/feed_fetch_method.h b/paddle/framework/feed_fetch_method.h index 7feacb1e24..b71945fcc8 100644 --- a/paddle/framework/feed_fetch_method.h +++ b/paddle/framework/feed_fetch_method.h @@ -13,46 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "glog/logging.h" + #include "paddle/framework/feed_fetch_type.h" #include "paddle/framework/scope.h" -#include "paddle/framework/variable.h" namespace paddle { namespace framework { void SetFeedVariable(Scope* scope, const LoDTensor& input, - const std::string& var_name, size_t index) { - // If var_name Variable is not found in GlobalScope, a new variable will - // be created. - VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; - Variable* g_feed_value = scope->Var(var_name); - auto& feed_inputs = - *(g_feed_value->GetMutable>()); - if (index >= feed_inputs.size()) { - feed_inputs.resize(index + 1); - } - // shared data with input tensor - feed_inputs[index].ShareDataWith(input); - // set lod - feed_inputs[index].set_lod(input.lod()); -} + const std::string& var_name, size_t index); LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, - size_t index) { - // Since we want to fetch LodTensor from a variable, the variable must - // be created alreadly. - Variable* g_fetch_value = scope.FindVar(var_name); - PADDLE_ENFORCE(g_fetch_value->IsType(), - "Only %s can be invoked by GetFetchVariable", - typeid(FeedFetchList).name()); - auto& fetch_outputs = *g_fetch_value->GetMutable(); - auto& tensor = fetch_outputs[index]; - VLOG(3) << "Fetch " << var_name << " with index " << index - << " shape= " << tensor.dims(); - PADDLE_ENFORCE_LT(index, fetch_outputs.size()); - return tensor; -} + size_t index); } // namespace framework } // namespace paddle diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc index 09268ffb3a..b43c359ed1 100644 --- a/paddle/inference/inference.cc +++ b/paddle/inference/inference.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "inference.h" #include #include "paddle/framework/executor.h" -#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/init.h" #include "paddle/framework/scope.h" @@ -154,7 +153,7 @@ void InferenceEngine::Execute(const std::vector& feeds, LOG(FATAL) << "Please initialize the program_ and load_program_ first."; } - if (feeds.size() < feed_var_names_.size()) { + if (feeds.size() != feed_var_names_.size()) { LOG(FATAL) << "Please feed " << feed_var_names_.size() << " input Tensors."; } @@ -165,19 +164,22 @@ void InferenceEngine::Execute(const std::vector& feeds, executor->Run(*load_program_, scope, 0, true, true); + std::map feed_targets; + std::map fetch_targets; + // set_feed_variable for (size_t i = 0; i < feed_var_names_.size(); ++i) { - framework::SetFeedVariable(scope, feeds[i], "feed", i); + feed_targets[feed_var_names_[i]] = &feeds[i]; } - executor->Run(*program_, scope, 0, true, true); - // get_fetch_variable fetchs.resize(fetch_var_names_.size()); for (size_t i = 0; i < fetch_var_names_.size(); ++i) { - fetchs[i] = framework::GetFetchVariable(*scope, "fetch", i); + fetch_targets[fetch_var_names_[i]] = &fetchs[i]; } + executor->Run(*program_, scope, feed_targets, fetch_targets); + delete place; delete scope; delete executor; diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index e78673e0ba..de53fea0dd 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED SRCS pybind.cc exception.cc protobuf.cc const_value.cc - DEPS pybind python backward proto_desc paddle_memory executor prune init profiler + DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method ${GLOB_OP_LIB}) if(NOT APPLE AND NOT ANDROID) target_link_libraries(paddle_pybind rt) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index b4fd2a8989..490397afdd 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -424,7 +424,9 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Executor") .def(py::init()) - .def("run", &Executor::Run); + .def("run", + (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) & + Executor::Run); m.def("unique_integer", UniqueIntegerGenerator); m.def("init_gflags", framework::InitGflags);