/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gru_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"

DECLARE_int32(paddle_num_threads);

namespace paddle {
namespace operators {

using framework::Tensor;

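// Shape conventions used throughout this operator (T = total time steps in
// the mini-batch, N = batch size, D = frame_size, i.e. the hidden size):
//   Input:  (T x 3D)  pre-projected gate inputs [xu, xr, xc]
//   Weight: (D x 3D)  hidden-to-hidden weights
//   H0:     (N x D)   optional initial hidden state
//   Bias:   (1 x 3D)  optional gate/candidate bias
//   Hidden: (T x D)   output hidden states in sequence order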
class GRUOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU");
    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
    OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
                   "BatchResetHiddenPrev", "GRU");
    OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
                   "GRU");
    OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU");

    auto input_dims = ctx->GetInputDim("Input");
    auto weight_dims = ctx->GetInputDim("Weight");
    int input_size = input_dims[1];
    int frame_size = weight_dims[0];
    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
                        platform::errors::InvalidArgument(
                            "The second dimension of Input(Input) must be 3 "
                            "times frame_size in GRUOp, but received %d "
                            "(Input) vs %d (frame_size).",
                            input_size, frame_size));
    }
    PADDLE_ENFORCE_EQ(
        weight_dims[1], frame_size * 3,
        platform::errors::InvalidArgument(
            "The shape of Input(Weight) matrix must be [frame_size, frame_size "
            "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
            weight_dims[0], weight_dims[1], frame_size, frame_size * 3));
    if (ctx->HasInput("H0")) {
      auto h0_dims = ctx->GetInputDim("H0");
      PADDLE_ENFORCE_EQ(
          h0_dims[1], frame_size,
          platform::errors::InvalidArgument(
              "The width of Input(H0) must be equal to frame_size, but "
              "received %d (width of H0) vs %d (frame_size).",
              h0_dims[1], frame_size));
    }
    if (ctx->HasInput("Bias")) {
      auto bias_dims = ctx->GetInputDim("Bias");
      int bias_height = bias_dims[0];
      int bias_width = bias_dims[1];
      PADDLE_ENFORCE_EQ(
          bias_height, 1,
          platform::errors::InvalidArgument(
              "The shape of Bias must be [1, frame_size * 3], but received "
              "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
              bias_height, bias_width, frame_size * 3));
      PADDLE_ENFORCE_EQ(
          bias_width, frame_size * 3,
          platform::errors::InvalidArgument(
              "The shape of Bias must be [1, frame_size * 3], but received "
              "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
              bias_height, bias_width, frame_size * 3));
    }
    ctx->SetOutputDim("BatchGate", input_dims);
    ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
    ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
    ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
    ctx->ShareLoD("Input", "Hidden");
  }
};

class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input",
             "(LoDTensor) The first input is a LoDTensor, which supports "
             "variable-length input sequences. The underlying tensor in "
             "this LoDTensor is a matrix with shape (T X 3D), where T is the "
             "total time steps in this mini-batch and D is the hidden size.");
    AddInput("H0",
             "(Tensor, optional) The initial hidden state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size and D is the hidden size.")
        .AsDispensable();
    AddInput(
        "Weight",
        "(Tensor) The learnable hidden-hidden weight matrix with shape "
        "(D x 3D), where D is the hidden size. The elements, contiguous in "
        "memory, can be divided into two parts. The first part are the "
        "weights of the update gate and reset gate with shape (D x 2D), and "
        "the second part are the weights of the output candidate with shape "
        "(D x D).");
    AddInput("Bias",
             "(Tensor, optional) Bias vector with shape (1 x 3D) "
             "concatenating the bias of the update gate, reset gate and "
             "output candidate.")
        .AsDispensable();
    AddOutput("BatchGate",
              "(LoDTensor) To compute with batches, sequence data will be "
              "reorganized into several successive batches, each containing "
              "data from the same time step. The LoDTensor BatchGate contains "
              "the update gate, reset gate and output candidate values "
              "organized in batches. The LoD size is 2. The first LoD contains "
              "the batch offsets and the second LoD contains the indexes in "
              "the raw sequence data.")
        .AsIntermediate();
    AddOutput(
        "BatchResetHiddenPrev",
        "(LoDTensor) The reset hidden state LoDTensor organized in batches. "
        "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
        "as `BatchGate`.")
        .AsIntermediate();
    AddOutput(
        "BatchHidden",
        "(LoDTensor) The hidden state LoDTensor organized in batches. "
        "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
        "as `BatchGate`.")
        .AsIntermediate();
    AddOutput(
        "Hidden",
        "(LoDTensor) The hidden state LoDTensor organized in sequences. "
        "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
        "as `BatchGate`.");
    AddAttr<std::string>("activation",
                         "(string, default tanh) "
                         "The activation type used for the output candidate "
                         "{h}_t.")
        .SetDefault("tanh");
    AddAttr<std::string>(
        "gate_activation",
        "(string, default sigmoid) "
        "The activation type used in the update gate and reset gate.")
        .SetDefault("sigmoid");
    AddAttr<bool>("is_reverse",
                  "(bool, default: False) "
                  "whether to compute the reversed GRU.")
        .SetDefault(false);
    AddAttr<bool>("origin_mode",
                  "(bool, default: False) "
                  "whether to use the origin mode of the GRU as described in "
                  "the article https://arxiv.org/abs/1412.3555.")
        .SetDefault(false);
    AddComment(R"DOC(
GRU Operator implements part of the calculations of the complete GRU as follows:

$$
update\_gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
reset\_gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r)  \\
output\_candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
$$

@note To implement the complete GRU, a fully-connected operator must be used
beforehand to feed xu, xr and xc as the Input of the GRU operator.
)DOC");
  }
};

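// Shape inference for the gradient op: the forward shape constraints are
// re-checked here, and the dims of the Input, Weight, H0 and Bias gradients
// are set whenever the corresponding gradient outputs are present.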
class GRUGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU@Grad");
    OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
                   "GRU@Grad");
    OP_INOUT_CHECK(ctx->HasInput("BatchResetHiddenPrev"), "Input",
                   "BatchResetHiddenPrev", "GRU@Grad");
    OP_INOUT_CHECK(ctx->HasInput("BatchHidden"), "Input", "BatchHidden",
                   "GRU@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "GRU@Grad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), "Input",
                   framework::GradVarName("Hidden"), "GRU@Grad");

    auto input_dims = ctx->GetInputDim("Input");
    auto weight_dims = ctx->GetInputDim("Weight");
    int input_size = input_dims[1];
    int frame_size = weight_dims[0];
    int weight_height = weight_dims[0];
    int weight_width = weight_dims[1];
    PADDLE_ENFORCE_EQ(
        input_size, frame_size * 3,
        platform::errors::InvalidArgument(
            "The second dimension of Input(Input) must be 3 times "
            "frame_size in GRUOp, but received %d (Input) vs %d (frame_size).",
            input_size, frame_size));
    PADDLE_ENFORCE_EQ(
        weight_height, frame_size,
        platform::errors::InvalidArgument(
            "The shape of Input(Weight) matrix must be [frame_size, frame_size "
            "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
            weight_height, weight_width, frame_size, frame_size * 3));
    PADDLE_ENFORCE_EQ(
        weight_width, frame_size * 3,
        platform::errors::InvalidArgument(
            "The shape of Input(Weight) matrix must be [frame_size, frame_size "
            "* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
            weight_height, weight_width, frame_size, frame_size * 3));
    if (ctx->HasInput("H0")) {
      auto h0_dims = ctx->GetInputDim("H0");
      PADDLE_ENFORCE_EQ(
          h0_dims[1], frame_size,
          platform::errors::InvalidArgument(
              "The width of Input(H0) must be equal to frame_size, but "
              "received %d (width of H0) vs %d (frame_size).",
              h0_dims[1], frame_size));
      auto h0_grad_name = framework::GradVarName("H0");
      if (ctx->HasOutput(h0_grad_name))
        ctx->SetOutputDim(h0_grad_name, h0_dims);
    }
    if (ctx->HasInput("Bias")) {
      auto bias_dims = ctx->GetInputDim("Bias");
      int bias_height = bias_dims[0];
      int bias_width = bias_dims[1];
      PADDLE_ENFORCE_EQ(
          bias_height, 1,
          platform::errors::InvalidArgument(
              "The shape of Bias must be [1, frame_size * 3], but received "
              "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
              bias_height, bias_width, frame_size * 3));
      PADDLE_ENFORCE_EQ(
          bias_width, frame_size * 3,
          platform::errors::InvalidArgument(
              "The shape of Bias must be [1, frame_size * 3], but received "
              "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
              bias_height, bias_width, frame_size * 3));
      auto bias_grad_name = framework::GradVarName("Bias");
      if (ctx->HasOutput(bias_grad_name))
        ctx->SetOutputDim(bias_grad_name, bias_dims);
    }
    auto input_grad_name = framework::GradVarName("Input");
    if (ctx->HasOutput(input_grad_name))
      ctx->SetOutputDim(input_grad_name, input_dims);
    auto weight_grad_name = framework::GradVarName("Weight");
    if (ctx->HasOutput(weight_grad_name))
      ctx->SetOutputDim(weight_grad_name, weight_dims);
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Hidden")),
                                   ctx.device_context());
  }
};

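// Forward CPU kernel. The input sequences are first reorganized into
// time-step batches (LoDTensor2BatchFunctor), the GRU recurrence is then run
// one time-step batch at a time, and the batched hidden states are finally
// scattered back to sequence order (Batch2LoDTensorFunctor). When MKL is
// available and enough threads are configured, the weight matrices are packed
// once and reused across all time steps.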
template <typename T>
class GRUCPUKernel : public framework::OpKernel<T> {
 public:
  void BatchCompute(const framework::ExecutionContext& context) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    bool origin_mode = context.Attr<bool>("origin_mode");
    auto* input = context.Input<LoDTensor>("Input");
    auto* h0 = context.Input<Tensor>("H0");
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* bias = context.Input<Tensor>("Bias");
    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(context.GetPlace());
    auto* batch_reset_hidden_prev =
        context.Output<LoDTensor>("BatchResetHiddenPrev");
    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
    batch_hidden->mutable_data<T>(context.GetPlace());
    auto* hidden = context.Output<LoDTensor>("Hidden");
    hidden->mutable_data<T>(context.GetPlace());

    auto hidden_dims = hidden->dims();

    bool is_reverse = context.Attr<bool>("is_reverse");
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);

    if (bias) {
      math::RowwiseAdd<DeviceContext, T> add_bias;
      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
    }

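    // Weight layout: the first 2 * D * D elements hold the update/reset gate
    // weights (D x 2D), the remaining D * D elements hold the output
    // candidate weights (D x D).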
    int frame_size = hidden_dims[1];
    math::GRUMetaValue<T> gru_value;
    gru_value.gate_weight = const_cast<T*>(weight_data);
    gru_value.state_weight =
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
    Tensor ordered_h0;

    framework::Vector<size_t> order(batch_gate->lod()[2]);

    if (h0) {
      // Since batch computing for GRU reorders the input sequences according
      // to their length, the initial hidden state H0 also needs to be
      // reordered.
      ReorderInitState<DeviceContext, T>(
          context.template device_context<DeviceContext>(), *h0, order,
          &ordered_h0, true);
      gru_value.prev_out_value = ordered_h0.data<T>();
    } else {
      gru_value.prev_out_value = nullptr;
    }
    auto batch_starts = batch_gate->lod()[0];
    size_t seq_len = batch_starts.size() - 1;
    auto active_node = math::detail::GetActivationType(
        context.Attr<std::string>("activation"));
    auto active_gate = math::detail::GetActivationType(
        context.Attr<std::string>("gate_activation"));

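    // Each iteration of the loops below processes one time-step batch:
    // rows batch_starts[n] .. batch_starts[n + 1] cover the n-th time step of
    // every sequence that is still active, so seq_len equals the length of
    // the longest sequence in the mini-batch.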
#ifdef PADDLE_WITH_MKLML
    // use the MKL packed format to speed up the repeated GEMMs
    if (FLAGS_paddle_num_threads >= 4) {
      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
                                       frame_size * 2 /*width of weight*/,
                                       frame_size /*height of weight*/);
      PADDLE_ENFORCE_NOT_NULL(
          packed_gate, platform::errors::NotFound(
                           "The calculation result of packed_gate by "
                           "GEMM_ALLOC should not be null when using MKL."));
      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
                     packed_gate);
      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
                                        frame_size /*width of weight*/,
                                        frame_size /*height of weight*/);
      PADDLE_ENFORCE_NOT_NULL(
          packed_state, platform::errors::NotFound(
                            "The calculation result of packed_state by "
                            "GEMM_ALLOC should not be null when using MKL."));
      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
                     frame_size, T(1.0), gru_value.state_weight, frame_size,
                     packed_state);
      for (size_t n = 0; n < seq_len; n++) {
        int bstart = static_cast<int>(batch_starts[n]);
        int bend = static_cast<int>(batch_starts[n + 1]);
        int cur_batch_size = bend - bstart;

        Tensor gate_t = batch_gate->Slice(bstart, bend);
        Tensor reset_hidden_prev_t =
            batch_reset_hidden_prev->Slice(bstart, bend);
        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
        gru_value.output_value = hidden_t.data<T>();
        gru_value.gate_value = gate_t.data<T>();
        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();

        if (gru_value.prev_out_value) {
          // Accumulate h_{t-1} * W_{u,r} into the update/reset gate columns
          // of gate_value using the packed gate weights.
          blas.GEMM_COMPUTE(
              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
        }

        // Activate the update/reset gates and compute r_t * h_{t-1} into
        // reset_output_value.
        math::detail::forward_reset_output(
            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
            cur_batch_size, active_gate);

        if (gru_value.prev_out_value) {
          // Accumulate (r_t * h_{t-1}) * W_c into the output-candidate
          // columns of gate_value using the packed state weights.
          blas.GEMM_COMPUTE(
              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
              gru_value.reset_output_value, frame_size, packed_state,
              frame_size, T(1), gru_value.gate_value + frame_size * 2,
              frame_size * 3);
        }

        // Activate the output candidate and compute the final hidden state
        // h_t for this time-step batch.
        math::detail::forward_final_output(
            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
            cur_batch_size, active_node, origin_mode);

        // The hidden state of this batch becomes the previous state for the
        // next time-step batch.
        gru_value.prev_out_value = gru_value.output_value;
      }

      blas.GEMM_FREE(packed_gate);
      blas.GEMM_FREE(packed_state);
    } else {
#endif
      for (size_t n = 0; n < seq_len; n++) {
        int bstart = static_cast<int>(batch_starts[n]);
        int bend = static_cast<int>(batch_starts[n + 1]);
        int cur_batch_size = bend - bstart;

        Tensor gate_t = batch_gate->Slice(bstart, bend);
        Tensor reset_hidden_prev_t =
            batch_reset_hidden_prev->Slice(bstart, bend);
        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
        gru_value.output_value = hidden_t.data<T>();
        gru_value.gate_value = gate_t.data<T>();
        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();

        // One fused GRU step for this time-step batch (gate GEMMs plus
        // activations).
        math::GRUUnitFunctor<DeviceContext, T>::compute(
            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
            active_gate, origin_mode);

        gru_value.prev_out_value = gru_value.output_value;
      }
#ifdef PADDLE_WITH_MKLML
    }
#endif
    // Scatter the batch-organized hidden states back to sequence order.
    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batch_hidden->set_lod(batch_gate->lod());
    to_seq(dev_ctx, *batch_hidden, hidden);
  }

  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

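// Builds the gru_grad op description: the forward inputs, the intermediate
// batch outputs and the gradient of Hidden are wired in as inputs, and the
// gradients of Input, Weight, H0 and Bias are declared as outputs.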
template <typename T>
class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("gru_grad");
    grad_op->SetInput("Input", this->Input("Input"));
    grad_op->SetInput("H0", this->Input("H0"));
    grad_op->SetInput("Bias", this->Input("Bias"));
    grad_op->SetInput("Weight", this->Input("Weight"));

    grad_op->SetInput("BatchGate", this->Output("BatchGate"));
    grad_op->SetInput("BatchResetHiddenPrev",
                      this->Output("BatchResetHiddenPrev"));
    grad_op->SetInput("BatchHidden", this->Output("BatchHidden"));
    grad_op->SetInput("Hidden", this->Output("Hidden"));

    grad_op->SetInput(framework::GradVarName("Hidden"),
                      this->OutputGrad("Hidden"));

    grad_op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0"));
    grad_op->SetOutput(framework::GradVarName("Input"),
                       this->InputGrad("Input"));
    grad_op->SetOutput(framework::GradVarName("Weight"),
                       this->InputGrad("Weight"));
    grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));

    grad_op->SetAttrMap(this->Attrs());
  }
};

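// Only the shapes of Input and Bias are needed by gru_grad, so their data
// buffers do not have to be kept alive for the backward pass.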
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUGradOpNoNeedBufferVarInferer, "Input",
                                    "Bias");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
                  ops::GRUGradOpMaker<paddle::framework::OpDesc>,
                  ops::GRUGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp,
                  ops::GRUGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
                       ops::GRUCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(
    gru_grad, ops::GRUGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GRUGradKernel<paddle::platform::CPUDeviceContext, double>);