commit
1d7954fc3f
@@ -0,0 +1,220 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/gru_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class GRUOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(%s) of GRUOp should not be null.", "Input");
    PADDLE_ENFORCE(ctx->HasInput("Weight"),
                   "Input(%s) of GRUOp should not be null.", "Weight");
    PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
                   "Output(%s) of GRUOp should not be null.", "BatchGate");
    PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
                   "Output(%s) of GRUOp should not be null.",
                   "BatchResetHiddenPrev");
    PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
                   "Output(%s) of GRUOp should not be null.", "BatchHidden");
    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                   "Output(%s) of GRUOp should not be null.", "Hidden");
    auto input_dims = ctx->GetInputDim("Input");
    auto weight_dims = ctx->GetInputDim("Weight");
    int input_size = input_dims[1];
    int frame_size = weight_dims[0];
    PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
                      "The input_size must be 3 times frame_size in GRUOp.");
    PADDLE_ENFORCE_EQ(
        weight_dims[1], frame_size * 3,
        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
    if (ctx->HasInput("H0")) {
      auto h0_dims = ctx->GetInputDim("H0");
      PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                        "The width of H0 must be equal to frame_size.");
    }
    if (ctx->HasInput("Bias")) {
      auto bias_dims = ctx->GetInputDim("Bias");
      int bias_height = bias_dims[0];
      int bias_width = bias_dims[1];
      PADDLE_ENFORCE_EQ(bias_height, 1,
                        "The shape of Bias must be [1, frame_size * 3].");
      PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
                        "The shape of Bias must be [1, frame_size * 3].");
    }
    ctx->SetOutputDim("BatchGate", input_dims);
    ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
    ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
    ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
    ctx->ShareLoD("Input", "Hidden");
  }
};

class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  GRUOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(LoDTensor) The first input is a LoDTensor, which supports "
             "variable-length input sequences. The underlying tensor in "
             "this LoDTensor is a matrix with shape (T x 3D), where T is the "
             "total number of time steps in this mini-batch and D is the "
             "hidden size.");
    AddInput("H0",
             "(Tensor, optional) The initial hidden state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size and D is the hidden size.")
        .AsDispensable();
    AddInput(
        "Weight",
        "(Tensor) The learnable hidden-hidden weight matrix with shape "
        "(D x 3D), where D is the hidden size. The elements continuous in "
        "memory can be divided into two parts. The first part are the "
        "weights of the update gate and reset gate with shape (D x 2D), and "
        "the second part are the weights of the output candidate with shape "
        "(D x D).");
    AddInput("Bias",
             "(Tensor, optional) Bias vector with shape (1 x 3D) "
             "concatenating the bias of the update gate, reset gate and "
             "output candidate.")
        .AsDispensable();
    AddOutput("BatchGate",
              "(LoDTensor) To compute with batches, sequence data will be "
              "reorganized into several successive batches, each containing "
              "data from the same time step. The LoDTensor BatchGate contains "
              "the update gate, reset gate and output candidate values "
              "organized in batches. The LoD size is 2. The first LoD contains "
              "the batch offsets and the second LoD contains the indexes in "
              "the raw sequence data.")
        .AsIntermediate();
    AddOutput(
        "BatchResetHiddenPrev",
        "(LoDTensor) The reset hidden state LoDTensor organized in batches. "
        "This LoDTensor is a matrix with shape (T x D) and has the same LoD "
        "as `BatchGate`.")
        .AsIntermediate();
    AddOutput(
        "BatchHidden",
        "(LoDTensor) The hidden state LoDTensor organized in batches. "
        "This LoDTensor is a matrix with shape (T x D) and has the same LoD "
        "as `BatchGate`.")
        .AsIntermediate();
    AddOutput(
        "Hidden",
        "(LoDTensor) The hidden state LoDTensor organized in sequences. "
        "This LoDTensor is a matrix with shape (T x D) and has the same LoD "
        "as `BatchGate`.");
    AddAttr<std::string>("activation",
                         "(string, default tanh) "
                         "The activation type used for the output candidate "
                         "{h}_t.")
        .SetDefault("tanh");
    AddAttr<std::string>(
        "gate_activation",
        "(string, default sigmoid) "
        "The activation type used in the update gate and reset gate.")
        .SetDefault("sigmoid");
    AddAttr<bool>("is_reverse",
                  "(bool, default: False) "
                  "whether to compute the reversed GRU.")
        .SetDefault(false);
    AddComment(R"DOC(
GRU Operator implements part of the calculations of the complete GRU as follows:

\f[
update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
\f]

@note To implement the complete GRU, a fully-connected operator must be used
beforehand to feed xu, xr and xc as the Input of the GRU operator.
)DOC");
  }
};

class GRUGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(%s) of GRUGradOp should not be null.", "Input");
    PADDLE_ENFORCE(ctx->HasInput("Weight"),
                   "Input(%s) of GRUGradOp should not be null.", "Weight");
    PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
                   "Input(%s) of GRUGradOp should not be null.", "BatchGate");
    PADDLE_ENFORCE(ctx->HasInput("BatchResetHiddenPrev"),
                   "Input(%s) of GRUGradOp should not be null.",
                   "BatchResetHiddenPrev");
    PADDLE_ENFORCE(ctx->HasInput("BatchHidden"),
                   "Input(%s) of GRUGradOp should not be null.", "BatchHidden");
    PADDLE_ENFORCE(ctx->HasInput("Hidden"),
                   "Input(%s) of GRUGradOp should not be null.", "Hidden");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
                   "Input(%s@GRAD) of GRUGradOp should not be null.", "Hidden");
    auto input_dims = ctx->GetInputDim("Input");
    auto weight_dims = ctx->GetInputDim("Weight");
    int input_size = input_dims[1];
    int frame_size = weight_dims[0];
    int weight_height = weight_dims[0];
    int weight_width = weight_dims[1];
    PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
                      "The input_size must be 3 times frame_size in GRUOp.");
    PADDLE_ENFORCE_EQ(
        weight_height, frame_size,
        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
    PADDLE_ENFORCE_EQ(
        weight_width, frame_size * 3,
        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
    if (ctx->HasInput("H0")) {
      auto h0_dims = ctx->GetInputDim("H0");
      PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                        "The width of H0 must be equal to frame_size.");
      auto h0_grad_name = framework::GradVarName("H0");
      if (ctx->HasOutput(h0_grad_name))
        ctx->SetOutputDim(h0_grad_name, h0_dims);
    }
    if (ctx->HasInput("Bias")) {
      auto bias_dims = ctx->GetInputDim("Bias");
      int bias_height = bias_dims[0];
      int bias_width = bias_dims[1];
      PADDLE_ENFORCE_EQ(bias_height, 1,
                        "The shape of Bias must be [1, frame_size * 3].");
      PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
                        "The shape of Bias must be [1, frame_size * 3].");
      auto bias_grad_name = framework::GradVarName("Bias");
      if (ctx->HasOutput(bias_grad_name))
        ctx->SetOutputDim(bias_grad_name, bias_dims);
    }
    auto input_grad_name = framework::GradVarName("Input");
    if (ctx->HasOutput(input_grad_name))
      ctx->SetOutputDim(input_grad_name, input_dims);
    auto weight_grad_name = framework::GradVarName("Weight");
    if (ctx->HasOutput(weight_grad_name))
      ctx->SetOutputDim(weight_grad_name, weight_dims);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(gru, ops::GRUOp, ops::GRUOpMaker, gru_grad, ops::GRUGradOp);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUKernel<paddle::platform::CPUPlace, float>,
                       ops::GRUKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(gru_grad,
                       ops::GRUGradKernel<paddle::platform::CPUPlace, float>,
                       ops::GRUGradKernel<paddle::platform::CPUPlace, double>);
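For reference, a minimal NumPy sketch (not part of this commit) of the single step described by the \f[...\f] equations in the GRUOp comment above. It assumes the gate activation is sigmoid and the candidate activation is tanh, that x_t already holds the projected inputs xu|xr|xc produced by a preceding fully-connected layer, and that the (D x 3D) Weight is laid out as in the commit: the first 2*D*D contiguous elements form the update/reset block and the rest the candidate block.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, weight, bias):
    # x_t: (N, 3D) pre-projected input; h_prev: (N, D); weight: (D, 3D); bias: (1, 3D)
    D = h_prev.shape[1]
    w_ur = weight.flatten()[:2 * D * D].reshape(D, 2 * D)  # update/reset block
    w_c = weight.flatten()[2 * D * D:].reshape(D, D)       # candidate block
    g = x_t + bias
    ur = sigmoid(g[:, :2 * D] + h_prev.dot(w_ur))
    u, r = ur[:, :D], ur[:, D:]                            # update gate, reset gate
    c = np.tanh(g[:, 2 * D:] + (r * h_prev).dot(w_c))      # output candidate {h}_t
    return u * c + (1.0 - u) * h_prev                      # h_t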
@@ -0,0 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/gru_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gru, ops::GRUKernel<paddle::platform::GPUPlace, float>,
                       ops::GRUKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(gru_grad,
                       ops::GRUGradKernel<paddle::platform::GPUPlace, float>,
                       ops::GRUGradKernel<paddle::platform::GPUPlace, double>);
@@ -0,0 +1,231 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> {
 public:
  void BatchCompute(const framework::ExecutionContext& context) const {
    auto* input = context.Input<LoDTensor>("Input");
    auto* h0 = context.Input<Tensor>("H0");
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* bias = context.Input<Tensor>("Bias");
    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(context.GetPlace());
    auto* batch_reset_hidden_prev =
        context.Output<LoDTensor>("BatchResetHiddenPrev");
    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
    batch_hidden->mutable_data<T>(context.GetPlace());
    auto* hidden = context.Output<LoDTensor>("Hidden");
    hidden->mutable_data<T>(context.GetPlace());

    context.ShareLoD("Input", "Hidden");

    auto hidden_dims = hidden->dims();

    bool is_reverse = context.Attr<bool>("is_reverse");
    math::LoDTensor2BatchFunctor<Place, T> to_batch;
    to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);

    int frame_size = hidden_dims[1];
    int batch_size = hidden_dims[0];
    auto g = EigenMatrix<T>::From(*batch_gate);
    auto place = context.GetEigenDevice<Place>();
    if (bias) {
      auto b = EigenMatrix<T>::From(*bias);
      g.device(place) = g +
                        b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
                            .broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
    }

    math::hl_gru_value<T> gru_value;
    gru_value.gateWeight = const_cast<T*>(weight_data);
    gru_value.stateWeight =
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
    gru_value.prevOutValue = const_cast<T*>(h0_data);
    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
      int cur_batch_size = bend - bstart;

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
      gru_value.outputValue = hidden_t.data<T>();
      gru_value.gateValue = gate_t.data<T>();
      gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
      math::GRUUnitFunctor<Place, T>::compute(
          context.device_context(), gru_value, frame_size, cur_batch_size,
          math::ActiveType(context.Attr<std::string>("activation")),
          math::ActiveType(context.Attr<std::string>("gate_activation")));
      gru_value.prevOutValue = gru_value.outputValue;
    }

    math::Batch2LoDTensorFunctor<Place, T> to_seq;
    batch_hidden->set_lod(batch_gate->lod());
    to_seq(context.device_context(), *batch_hidden, *hidden);
  }

  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

template <typename Place, typename T>
class GRUGradKernel : public framework::OpKernel<T> {
 public:
  void BatchCompute(const framework::ExecutionContext& context) const {
    auto* h0 = context.Input<Tensor>("H0");
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* batch_gate = context.Input<LoDTensor>("BatchGate");
    auto* batch_reset_hidden_prev =
        context.Input<LoDTensor>("BatchResetHiddenPrev");
    auto* batch_hidden = context.Input<LoDTensor>("BatchHidden");
    auto* hidden = context.Input<LoDTensor>("Hidden");
    auto* hidden_grad =
        context.Input<LoDTensor>(framework::GradVarName("Hidden"));
    auto* input_grad =
        context.Output<LoDTensor>(framework::GradVarName("Input"));
    auto* h0_grad = context.Output<Tensor>(framework::GradVarName("H0"));
    auto* weight_grad =
        context.Output<Tensor>(framework::GradVarName("Weight"));
    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));

    auto gate_dims = batch_gate->dims();
    auto hidden_dims = hidden->dims();
    int frame_size = hidden_dims[1];

    math::LoDTensor2BatchFunctor<Place, T> to_batch;
    LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad;
    batch_hidden_grad.mutable_data<T>(hidden_dims, context.GetPlace());
    batch_gate_grad.mutable_data<T>(gate_dims, context.GetPlace());
    batch_reset_hidden_prev_grad.mutable_data<T>(hidden_dims,
                                                 context.GetPlace());
    math::SetConstant<Place, T> zero;
    zero(context.device_context(), &batch_hidden_grad, static_cast<T>(0.0));
    zero(context.device_context(), &batch_gate_grad, static_cast<T>(0.0));
    zero(context.device_context(), &batch_reset_hidden_prev_grad,
         static_cast<T>(0.0));

    bool is_reverse = context.Attr<bool>("is_reverse");
    batch_hidden_grad.set_lod(batch_hidden->lod());
    to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
             is_reverse);

    math::hl_gru_value<T> gru_value;
    gru_value.gateWeight = const_cast<T*>(weight_data);
    gru_value.stateWeight =
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);

    math::hl_gru_grad<T> gru_grad;
    if (weight_grad) {
      gru_grad.gateWeightGrad =
          weight_grad->mutable_data<T>(context.GetPlace());
      zero(context.device_context(), weight_grad, static_cast<T>(0.0));
      gru_grad.stateWeightGrad =
          weight_grad->data<T>() + 2 * frame_size * frame_size;
    } else {
      gru_grad.gateWeightGrad = nullptr;
      gru_grad.stateWeightGrad = nullptr;
    }

    auto batch_starts = batch_hidden_grad.lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
      int cur_batch_size = bend - bstart;

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      gru_value.gateValue = gate_t.data<T>();
      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
      gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();

      Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend);
      gru_grad.outputGrad = hidden_grad_t.data<T>();
      Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend);
      gru_grad.gateGrad = gate_grad_t.data<T>();
      Tensor reset_hidden_prev_grad_t =
          batch_reset_hidden_prev_grad.Slice(bstart, bend);
      gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
      if (n == 0) {
        gru_value.prevOutValue = const_cast<T*>(h0_data);
        if (h0_grad) {
          T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
          zero(context.device_context(), h0_grad, static_cast<T>(0.0));
          gru_grad.prevOutGrad = h0_grad_data;
        } else {
          gru_grad.prevOutGrad = nullptr;
        }
      } else {
        int bstart_pre = static_cast<int>(batch_starts[n - 1]);
        Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
        gru_value.prevOutValue = hidden_prev_t.data<T>();
        Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
        gru_grad.prevOutGrad = hidden_prev_grad_t.data<T>();
      }

      math::GRUUnitGradFunctor<Place, T>::compute(
          context.device_context(), gru_value, gru_grad, frame_size,
          cur_batch_size,
          math::ActiveType(context.Attr<std::string>("activation")),
          math::ActiveType(context.Attr<std::string>("gate_activation")));
    }
    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
      math::Batch2LoDTensorFunctor<Place, T> to_seq;
      batch_gate_grad.set_lod(batch_gate->lod());
      to_seq(context.device_context(), batch_gate_grad, *input_grad);
    }
    if (bias_grad) {
      bias_grad->mutable_data<T>(context.GetPlace());
      auto d_b = EigenMatrix<T>::From(*bias_grad);
      auto d_g = EigenMatrix<T>::From(batch_gate_grad);
      auto place = context.GetEigenDevice<Place>();
      d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

}  // namespace operators
}  // namespace paddle
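The GRUKernel and GRUGradKernel above lean on math::LoDTensor2BatchFunctor (defined in sequence2batch.h, not shown in this commit) to regroup the variable-length sequences into per-time-step batches before each GRUUnitFunctor call. As a rough illustration only, mirroring seq_to_batch in test_gru_op.py at the end of this commit: sequences are visited longest first, and batch k collects the k-th element of every sequence that still has one.

def seq_to_batch(lod, is_reverse=False):
    # lod[0] holds the sequence start offsets in the raw input, e.g. [0, 2, 6, 9].
    seq_starts = lod[0]
    seq_lens = [seq_starts[i + 1] - seq_starts[i] for i in range(len(seq_starts) - 1)]
    order = sorted(range(len(seq_lens)), key=lambda i: -seq_lens[i])  # longest first
    idx_in_seq_list = []
    for step in range(seq_lens[order[0]]):
        idx_in_seq = []
        for i in order:
            if seq_lens[i] <= step:
                break  # remaining sequences are shorter; this batch is complete
            idx = (seq_starts[i + 1] - 1 - step) if is_reverse else (seq_starts[i] + step)
            idx_in_seq.append(idx)
        idx_in_seq_list.append(idx_in_seq)
    return idx_in_seq_list  # idx_in_seq_list[k]: rows of the raw input used at step k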
File diff suppressed because it is too large
@@ -0,0 +1,203 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"

#include <glog/logging.h>

namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * threads(framePerBlock, batchPerBlock)
 * grid(frameBlocks, batchBlocks)
 */
template <class OpResetOutput, bool isBatch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput,
                                        T *gateValue, T *resetOutputValue,
                                        T *prevOutputValue, int frameSize,
                                        int batchSize,
                                        activation_mode_t active_gate) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  int batchIdx = 0;
  if (isBatch) {
    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    gateValue += batchIdx * 3 * frameSize;
    resetOutputValue += batchIdx * frameSize;
  }

  T rPrevOut = 0;
  T rValueResetOutput;
  T rValueUpdateGate = gateValue[frameIdx + frameSize * 0];
  T rValueResetGate = gateValue[frameIdx + frameSize * 1];

  if (prevOutputValue) {
    if (isBatch) prevOutputValue += batchIdx * frameSize;
    rPrevOut = prevOutputValue[frameIdx];
  }

  opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput,
                active_gate);

  gateValue[frameIdx + frameSize * 0] = rValueUpdateGate;
  gateValue[frameIdx + frameSize * 1] = rValueResetGate;
  resetOutputValue[frameIdx] = rValueResetOutput;
}

/*
 * threads(framePerBlock, batchPerBlock)
 * grid(frameBlocks, batchBlocks)
 */
template <class OpFinalOutput, bool isBatch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput,
                                        T *gateValue, T *prevOutputValue,
                                        T *outputValue, int frameSize,
                                        int batchSize,
                                        activation_mode_t active_node) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;
  int batchIdx = 0;
  if (isBatch) {
    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    gateValue += batchIdx * 3 * frameSize;
    outputValue += batchIdx * frameSize;
  }

  T rOutput;
  T rPrevOut = 0;
  T rValueUpdateGate = gateValue[frameIdx + frameSize * 0];
  T rValueFrameState = gateValue[frameIdx + frameSize * 2];

  if (prevOutputValue) {
    if (isBatch) prevOutputValue += batchIdx * frameSize;
    rPrevOut = prevOutputValue[frameIdx];
  }

  opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
                active_node);

  gateValue[frameIdx + frameSize * 2] = rValueFrameState;
  outputValue[frameIdx] = rOutput;
}

/*
 * threads(framePerBlock, batchPerBlock)
 * grid(frameBlocks, batchBlocks)
 */
template <class OpStateGrad, bool isBatch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue,
                                       T *gateGrad, T *prevOutValue,
                                       T *prevOutGrad, T *outputGrad,
                                       int frameSize, int batchSize,
                                       activation_mode_t active_node) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;
  int batchIdx = 0;
  if (isBatch) {
    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    gateValue += batchIdx * 3 * frameSize;
    gateGrad += batchIdx * 3 * frameSize;
    outputGrad += batchIdx * frameSize;
  }

  T rUpdateGateGrad;
  T rFrameStateGrad;
  T rPrevOutValue = 0;
  T rPrevOutGrad = 0;
  T rUpdateGateValue = gateValue[frameIdx + frameSize * 0];
  T rFrameStateValue = gateValue[frameIdx + frameSize * 2];
  T rOutGrad = outputGrad[frameIdx];

  if (prevOutValue && prevOutGrad) {
    if (isBatch) prevOutValue += batchIdx * frameSize;
    rPrevOutValue = prevOutValue[frameIdx];

    if (isBatch) prevOutGrad += batchIdx * frameSize;
    rPrevOutGrad = prevOutGrad[frameIdx];
  }

  opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
              rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
              active_node);

  gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
  gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad;
  if (prevOutGrad) {
    prevOutGrad[frameIdx] = rPrevOutGrad;
  }
}

/*
 * threads(framePerBlock, batchPerBlock)
 * grid(frameBlocks, batchBlocks)
 */
template <class OpResetGrad, bool isBatch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue,
                                       T *gateGrad, T *prevOutValue,
                                       T *prevOutGrad, T *resetOutputGrad,
                                       int frameSize, int batchSize,
                                       activation_mode_t active_gate) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;
  int batchIdx = 0;
  if (isBatch) {
    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    gateValue += batchIdx * 3 * frameSize;
    gateGrad += batchIdx * 3 * frameSize;
    resetOutputGrad += batchIdx * frameSize;
  }

  T rResetGateGrad;
  T rPrevOutValue = 0;
  T rPrevOutGrad = 0;
  T rResetOutputGrad = 0;
  T rUpdateGateValue = gateValue[frameIdx + frameSize * 0];
  T rUpdateGateGrad = gateGrad[frameIdx + frameSize * 0];
  T rResetGateValue = gateValue[frameIdx + frameSize * 1];

  if (prevOutValue && prevOutGrad) {
    if (isBatch) prevOutValue += batchIdx * frameSize;
    if (isBatch) prevOutGrad += batchIdx * frameSize;
    rPrevOutValue = prevOutValue[frameIdx];
    rPrevOutGrad = prevOutGrad[frameIdx];
    rResetOutputGrad = resetOutputGrad[frameIdx];
  }

  opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
              rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
              active_gate);

  gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
  gateGrad[frameIdx + frameSize * 1] = rResetGateGrad;
  if (prevOutGrad) {
    prevOutGrad[frameIdx] = rPrevOutGrad;
  }
}
}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle
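The threads(framePerBlock, batchPerBlock) / grid(frameBlocks, batchBlocks) comments above refer to the launch configuration chosen by GRUUnitFunctor<GPUPlace, T> in gru_compute.cu later in this commit: one thread handles one (frame element, batch row) pair. A small sketch of that choice, for illustration only:

def gru_launch_config(frame_size, batch_size):
    # Mirrors the dim3 threads/grid computed in gru_compute.cu.
    if batch_size == 1:
        threads = (min(frame_size, 1024), 1)          # framePerBlock, batchPerBlock
        grid = ((frame_size + 1024 - 1) // 1024, 1)   # frameBlocks, batchBlocks
    else:
        threads = (32, 32)
        grid = ((frame_size + 32 - 1) // 32, (batch_size + 32 - 1) // 32)
    return threads, grid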
@@ -0,0 +1,155 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/hostdevice.h"

#include <type_traits>

// TODO(guosheng): refine code style in gru_kernel
namespace paddle {
namespace operators {
namespace math {
namespace detail {

namespace forward {

template <typename T>
class gru_resetOutput {
 public:
  HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut,
                             T &valueResetOutput, activation_mode_t actGate) {
    valueUpdateGate = activation(valueUpdateGate, actGate);
    valueResetGate = activation(valueResetGate, actGate);
    valueResetOutput = prevOut * valueResetGate;
  }
#ifndef __NVCC__
#ifndef __AVX__
  static const bool avx = false;
#else
  static const bool avx = true;
  HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate,
                             __m256 &prevOut, __m256 &valueResetOutput,
                             activation_mode_t actGate) {
    valueUpdateGate = activation(valueUpdateGate, actGate);
    valueResetGate = activation(valueResetGate, actGate);
    valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate);
  }
#endif
#endif
};

template <typename T>
class gru_finalOutput {
 public:
  HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut,
                             T &valueOutput, activation_mode_t actInput) {
    valueFrameState = activation(valueFrameState, actInput);
    valueOutput = prevOut - (valueUpdateGate * prevOut) +
                  (valueUpdateGate * valueFrameState);
  }
#ifndef __NVCC__
#ifndef __AVX__
  static const bool avx = false;
#else
  static const bool avx = true;
  HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState,
                             __m256 &prevOut, __m256 &valueOutput,
                             activation_mode_t actInput) {
    valueFrameState = activation(valueFrameState, actInput);
    valueOutput = _mm256_add_ps(
        _mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)),
        _mm256_mul_ps(valueUpdateGate, valueFrameState));
  }
#endif
#endif
};
}  // namespace forward

namespace backward {

template <typename T>
class gru_stateGrad {
 public:
  HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
                             T &valueFrameState, T &gradFrameState,
                             T &valuePrevOut, T &gradPrevOut, T &gradOutput,
                             activation_mode_t actInput) {
    gradUpdateGate = (gradOutput * valueFrameState);
    gradUpdateGate -= (gradOutput * valuePrevOut);
    gradPrevOut -= (gradOutput * valueUpdateGate);
    gradPrevOut += gradOutput;
    gradFrameState =
        activation(gradOutput * valueUpdateGate, valueFrameState, actInput);
  }
#ifndef __NVCC__
#ifndef __AVX__
  static const bool avx = false;
#else
  static const bool avx = true;
  HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
                             __m256 &valueFrameState, __m256 &gradFrameState,
                             __m256 &valuePrevOut, __m256 &gradPrevOut,
                             __m256 &gradOutput, activation_mode_t actInput) {
    gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState);
    gradUpdateGate =
        _mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut));
    gradPrevOut = _mm256_add_ps(
        _mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)),
        gradOutput);
    gradFrameState = activation(_mm256_mul_ps(gradOutput, valueUpdateGate),
                                valueFrameState, actInput);
  }
#endif
#endif
};

template <typename T>
class gru_resetGrad {
 public:
  HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
                             T &valueResetGate, T &gradResetGate,
                             T &valuePrevOut, T &gradPrevOut,
                             T &gradResetOutput, activation_mode_t actGate) {
    gradResetGate = (gradResetOutput * valuePrevOut);
    gradPrevOut += (gradResetOutput * valueResetGate);
    gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
    gradResetGate = activation(gradResetGate, valueResetGate, actGate);
  }
#ifndef __NVCC__
#ifndef __AVX__
  static const bool avx = false;
#else
  static const bool avx = true;
  HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
                             __m256 &valueResetGate, __m256 &gradResetGate,
                             __m256 &valuePrevOut, __m256 &gradPrevOut,
                             __m256 &gradResetOutput,
                             activation_mode_t actGate) {
    gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut);
    gradPrevOut = _mm256_add_ps(
        gradPrevOut, _mm256_mul_ps(gradResetOutput, valueResetGate));
    gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
    gradResetGate = activation(gradResetGate, valueResetGate, actGate);
  }
#endif
#endif
};

}  // namespace backward

}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle
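The elementwise gradients implemented by gru_stateGrad and gru_resetGrad above can be read off the forward equations in the GRUOp comment. Writing the output as h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t) and the reset output as s_t = dot(r_t, h_{t-1}), and with the activation(...) calls folding in the activation derivatives, the chain rule gives (informal note, same dot() notation as the operator doc):

\f[
grad\_u_t = dot(grad\_h_t, ({h}_t - h_{t-1})) \\
grad\_{h}_t = dot(grad\_h_t, u_t) \\
grad\_h_{t-1} += dot(grad\_h_t, (1 - u_t)) \\
grad\_r_t = dot(grad\_s_t, h_{t-1}) \\
grad\_h_{t-1} += dot(grad\_s_t, r_t)
\f]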
@@ -0,0 +1,102 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct GRUUnitFunctor<platform::CPUPlace, T> {
  static void compute(const platform::DeviceContext &context,
                      hl_gru_value<T> value, int frameSize, int batchSize,
                      activation_mode_t active_node,
                      activation_mode_t active_gate) {
#ifndef __NVCC__
    if (value.prevOutValue) {
      math::gemm<platform::CPUPlace, T>(
          context, false, false, batchSize, frameSize * 2, frameSize, 1,
          value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
          value.gateValue, frameSize * 3);
    }

    detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
                                 frameSize, batchSize, active_gate);

    if (value.prevOutValue) {
      math::gemm<platform::CPUPlace, T>(
          context, false, false, batchSize, frameSize, frameSize, 1,
          value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
          value.gateValue + frameSize * 2, frameSize * 3);
    }

    detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
                                 frameSize, batchSize, active_node);
#endif
  }
};

template <typename T>
struct GRUUnitGradFunctor<platform::CPUPlace, T> {
  static void compute(const platform::DeviceContext &context,
                      hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
                      int batchSize, activation_mode_t active_node,
                      activation_mode_t active_gate) {
#ifndef __NVCC__
    detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
                                grad, frameSize, batchSize, active_node);

    if (value.prevOutValue && grad.prevOutGrad) {
      math::gemm<platform::CPUPlace, T>(
          context, false, true, batchSize, frameSize, frameSize, 1,
          grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
          frameSize, 0, grad.resetOutputGrad, frameSize);

      if (grad.stateWeightGrad) {
        math::gemm<platform::CPUPlace, T>(
            context, true, false, frameSize, frameSize, batchSize, 1,
            value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
            frameSize * 3, 1, grad.stateWeightGrad, frameSize);
      }
    }

    detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
                                grad, frameSize, batchSize, active_gate);

    if (grad.prevOutGrad && value.prevOutValue) {
      math::gemm<platform::CPUPlace, T>(
          context, false, true, batchSize, frameSize, frameSize * 2, 1,
          grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
          grad.prevOutGrad, frameSize);

      if (grad.gateWeightGrad) {
        math::gemm<platform::CPUPlace, T>(
            context, true, false, frameSize, frameSize * 2, batchSize, 1,
            value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
            grad.gateWeightGrad, frameSize * 2);
      }
    }
#endif
  }
};

template struct GRUUnitFunctor<platform::CPUPlace, float>;
template struct GRUUnitFunctor<platform::CPUPlace, double>;
template struct GRUUnitGradFunctor<platform::CPUPlace, float>;
template struct GRUUnitGradFunctor<platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
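Put differently, the CPU forward path above is two small GEMMs interleaved with the elementwise kernels: h_{t-1} * W_{u,r} is accumulated into the first 2*frameSize gate columns before forward_reset_output runs, and dot(r_t, h_{t-1}) * W_c is accumulated into the candidate column before forward_final_output. A NumPy sketch of the same decomposition (illustration only, not part of the commit, assuming sigmoid gates and a tanh candidate):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_unit_forward(gate_value, prev_out, gate_weight, state_weight):
    # gate_value: (N, 3D) already holding xu|xr|xc (+ bias); prev_out: (N, D)
    # gate_weight: (D, 2D) update/reset block; state_weight: (D, D) candidate block
    D = prev_out.shape[1]
    gate_value[:, :2 * D] += prev_out.dot(gate_weight)       # GEMM 1
    u = sigmoid(gate_value[:, :D])                           # update gate
    r = sigmoid(gate_value[:, D:2 * D])                      # reset gate
    reset_output = r * prev_out                              # resetOutputValue
    gate_value[:, 2 * D:] += reset_output.dot(state_weight)  # GEMM 2
    c = np.tanh(gate_value[:, 2 * D:])                       # output candidate
    return reset_output, u * c + (1.0 - u) * prev_out        # h_t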
@@ -0,0 +1,178 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct GRUUnitFunctor<platform::GPUPlace, T> {
  static void compute(const platform::DeviceContext &context,
                      hl_gru_value<T> value, int frameSize, int batchSize,
                      activation_mode_t active_node,
                      activation_mode_t active_gate) {
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
    dim3 threads;
    dim3 grid;
    if (batchSize == 1) {
      int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
      int frameBlocks = (frameSize + 1024 - 1) / 1024;
      threads = dim3(framePerBlock, 1);
      grid = dim3(frameBlocks, 1);
    } else {
      threads = dim3(32, 32);
      grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
    }

    if (value.prevOutValue) {
      math::gemm<platform::GPUPlace, T>(
          context, false, false, batchSize, frameSize * 2, frameSize, 1,
          value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
          value.gateValue, frameSize * 3);
    }

    if (batchSize == 1) {
      detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
                                      /* isBatch= */ false,
                                      T><<<grid, threads, 0, stream>>>(
          detail::forward::gru_resetOutput<T>(), value.gateValue,
          value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
          active_gate);
    } else {
      detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
                                      /* isBatch= */ true,
                                      T><<<grid, threads, 0, stream>>>(
          detail::forward::gru_resetOutput<T>(), value.gateValue,
          value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
          active_gate);
    }

    if (value.prevOutValue) {
      math::gemm<platform::GPUPlace, T>(
          context, false, false, batchSize, frameSize, frameSize, 1,
          value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
          value.gateValue + frameSize * 2, frameSize * 3);
    }

    if (batchSize == 1) {
      detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
                                      /* isBatch= */ false,
                                      T><<<grid, threads, 0, stream>>>(
          detail::forward::gru_finalOutput<T>(), value.gateValue,
          value.prevOutValue, value.outputValue, frameSize, batchSize,
          active_node);
    } else {
      detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
                                      /* isBatch= */ true,
                                      T><<<grid, threads, 0, stream>>>(
          detail::forward::gru_finalOutput<T>(), value.gateValue,
          value.prevOutValue, value.outputValue, frameSize, batchSize,
          active_node);
    }
  }
};

template <typename T>
struct GRUUnitGradFunctor<platform::GPUPlace, T> {
  static void compute(const platform::DeviceContext &context,
                      hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
                      int batchSize, activation_mode_t active_node,
                      activation_mode_t active_gate) {
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
    dim3 threads;
    dim3 grid;
    if (batchSize == 1) {
      int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
      int frameBlocks = (frameSize + 1024 - 1) / 1024;
      threads = dim3(framePerBlock, 1);
      grid = dim3(frameBlocks, 1);
    } else {
      threads = dim3(32, 32);
      grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
    }

    if (batchSize == 1) {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* isBatch= */ false><<<grid, threads, 0, stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
          value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
          batchSize, active_node);
    } else {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* isBatch= */ true><<<grid, threads, 0, stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
          value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
          batchSize, active_node);
    }

    if (value.prevOutValue && grad.prevOutGrad) {
      math::gemm<platform::GPUPlace, T>(
          context, false, true, batchSize, frameSize, frameSize, 1,
          grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
          frameSize, 0, grad.resetOutputGrad, frameSize);

      if (grad.stateWeightGrad) {
        math::gemm<platform::GPUPlace, T>(
            context, true, false, frameSize, frameSize, batchSize, 1,
            value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
            frameSize * 3, 1, grad.stateWeightGrad, frameSize);
      }
    }

    if (batchSize == 1) {
      detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
          /* isBatch= */ false><<<grid, threads, 0, stream>>>(
          detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
          value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
          batchSize, active_gate);
    } else {
      detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
          /* isBatch= */ true><<<grid, threads, 0, stream>>>(
          detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
          value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
          batchSize, active_gate);
    }

    if (grad.prevOutGrad && value.prevOutValue) {
      math::gemm<platform::GPUPlace, T>(
          context, false, true, batchSize, frameSize, frameSize * 2, 1,
          grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
          grad.prevOutGrad, frameSize);

      if (grad.gateWeightGrad) {
        math::gemm<platform::GPUPlace, T>(
            context, true, false, frameSize, frameSize * 2, batchSize, 1,
            value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
            grad.gateWeightGrad, frameSize * 2);
      }
    }
  }
};

template struct GRUUnitFunctor<platform::GPUPlace, float>;
template struct GRUUnitFunctor<platform::GPUPlace, double>;
template struct GRUUnitGradFunctor<platform::GPUPlace, float>;
template struct GRUUnitGradFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,61 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace operators {
namespace math {

// TODO(guosheng): refine code style in gru_compute
template <typename T>
struct hl_gru_value {
  T *gateWeight;
  T *stateWeight;
  T *gateValue;
  T *resetOutputValue;
  T *outputValue;
  T *prevOutValue;
};

template <typename T>
struct hl_gru_grad {
  T *gateWeightGrad;
  T *stateWeightGrad;
  T *gateGrad;
  T *resetOutputGrad;
  T *outputGrad;
  T *prevOutGrad;
};

template <typename Place, typename T>
struct GRUUnitFunctor {
  static void compute(const platform::DeviceContext &context,
                      hl_gru_value<T> value, int frameSize, int batchSize,
                      activation_mode_t active_node,
                      activation_mode_t active_gate);
};

template <typename Place, typename T>
struct GRUUnitGradFunctor {
  static void compute(const platform::DeviceContext &context,
                      hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
                      int batchSize, activation_mode_t active_node,
                      activation_mode_t active_gate);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
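The two weight pointers in hl_gru_value (and their counterparts in hl_gru_grad) are views into the single (D x 3D) Weight tensor: as gru_op.h sets them up, gateWeight points at the first 2*D*D elements (the update/reset block, used as a D x 2D matrix) and stateWeight at the remaining D*D elements (the candidate block). A small NumPy sketch of that split, for illustration only:

import numpy as np

def split_gru_weight(weight, frame_size):
    # weight: the (frame_size, 3 * frame_size) Weight tensor of the GRU op
    flat = np.asarray(weight).reshape(-1)
    gate_weight = flat[:2 * frame_size * frame_size].reshape(
        frame_size, 2 * frame_size)   # corresponds to gateWeight
    state_weight = flat[2 * frame_size * frame_size:].reshape(
        frame_size, frame_size)       # corresponds to stateWeight
    return gate_weight, state_weight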
@@ -0,0 +1,156 @@
import unittest
import numpy as np
import math
from op_test import OpTest
from test_lstm_op import identity, sigmoid, tanh, relu


class TestGRUOp(OpTest):
    batch_size = 9
    frame_size = 5
    activate = {
        'identity': identity,
        'sigmoid': sigmoid,
        'tanh': tanh,
        'relu': relu
    }

    @staticmethod
    def seq_to_batch(lod, is_reverse):
        idx_in_seq_list = []
        seq_starts = lod[0]
        seq_lens = []
        for i in range(len(seq_starts) - 1):
            seq_lens.append(seq_starts[i + 1] - seq_starts[i])
        sorted_seqs = sorted(
            range(len(seq_lens)), lambda x, y: seq_lens[y] - seq_lens[x])
        num_batch = seq_lens[sorted_seqs[0]]
        for batch_idx in range(num_batch):
            idx_in_seq = []
            for i in range(len(seq_lens)):
                if seq_lens[sorted_seqs[i]] <= batch_idx:
                    break
                idx = (seq_starts[sorted_seqs[i] + 1] - 1 - batch_idx
                       ) if is_reverse else (
                           seq_starts[sorted_seqs[i]] + batch_idx)
                idx_in_seq.append(idx)
            idx_in_seq_list.append(idx_in_seq)
        return idx_in_seq_list

    def gru_step(self, x, h_p, w, b):
        batch_size = x.shape[0]
        frame_size = w.shape[0]
        g = x + np.tile(b, (batch_size, 1))
        w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
            (frame_size, frame_size * 2))
        u_r = self.activate[self.attrs['gate_activation']](np.dot(
            h_p, w_u_r) + g[:, :frame_size * 2])
        u = u_r[:, :frame_size]
        r = u_r[:, frame_size:frame_size * 2]
        r_h_p = r * h_p
        w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
            (frame_size, frame_size))
        c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
                                                    g[:, frame_size * 2:])
        g = np.hstack((u_r, c))
        h = u * c + (1 - u) * h_p
        return g, r_h_p, h

    def gru(self):
        input, lod = self.inputs['Input']
        w = self.inputs['Weight']
        b = self.inputs['Bias'] if self.inputs.has_key('Bias') else np.zeros(
            (1, self.frame_size * 3))
        batch_gate = self.outputs['BatchGate']
        batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev']
        batch_hidden = self.outputs['BatchHidden']
        hidden = self.outputs['Hidden']
        idx_in_seq_list = self.idx_in_seq_list
        h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros(
            (len(idx_in_seq_list[0]), self.frame_size))
        num_batch = len(idx_in_seq_list)
        end_idx = 0
        for batch_idx in range(num_batch):
            x = input[idx_in_seq_list[batch_idx]]
            g, r_h_p, h = self.gru_step(x, h_p, w, b)
            if batch_idx < (num_batch - 1):
                h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
            start_idx = end_idx
            end_idx = start_idx + len(idx_in_seq_list[batch_idx])
            batch_gate[start_idx:end_idx] = g
            batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
            batch_hidden[start_idx:end_idx] = h
            hidden[idx_in_seq_list[batch_idx]] = h
        return batch_gate, batch_reset_hidden_prev, hidden

    def set_data(self):
        lod = [[0, 2, 6, self.batch_size]]
        self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
        batch_size = self.batch_size
        frame_size = self.frame_size
        input = np.random.rand(batch_size, frame_size * 3).astype('float64')
        h0 = np.random.rand(len(self.idx_in_seq_list[0]),
                            frame_size).astype('float64')
        weight = np.random.rand(frame_size, frame_size * 3).astype('float64')
        bias = np.random.rand(1, frame_size * 3).astype('float64')

        self.inputs = {
            'Input': (input, lod),
            'H0': h0,
            'Weight': weight,
            'Bias': bias
        }

        self.outputs = {
            'BatchGate': np.zeros(
                (batch_size, frame_size * 3), dtype='float64'),
            'BatchResetHiddenPrev': np.zeros(
                (batch_size, frame_size), dtype='float64'),
            'BatchHidden': np.zeros(
                (batch_size, frame_size), dtype='float64'),
            'Hidden': np.zeros(
                (batch_size, frame_size), dtype='float64')
        }

    def set_confs(self):
        self.is_reverse = False
        self.attrs = {
            'activation': 'tanh',
            'gate_activation': 'sigmoid',
            'is_reverse': self.is_reverse
        }

    def setUp(self):
        self.op_type = "gru"
        self.set_confs()
        self.set_data()
        self.gru()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])


class TestGRUOpNoInitial(TestGRUOp):
    def set_data(self):
        super(TestGRUOpNoInitial, self).set_data()
        self.inputs.pop('H0')

    def test_check_grad(self):
        self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])


class TestGRUOpReverse(TestGRUOp):
    def set_confs(self):
        self.is_reverse = True
        self.attrs = {
            'activation': 'identity',
            'gate_activation': 'sigmoid',
            'is_reverse': self.is_reverse
        }


if __name__ == "__main__":
    unittest.main()