Merge pull request #4443 from guoshengCS/add-GRUStepOp
Add gru_unit_op
commit a0af1eeabf
@@ -0,0 +1,210 @@ paddle/operators/gru_unit_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/gru_unit_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class GRUUnitOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(%s) of GRUUnitOp should not be null.", "Input");
    PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
                   "Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
    PADDLE_ENFORCE(ctx->HasInput("Weight"),
                   "Input(%s) of GRUUnitOp should not be null.", "Weight");
    PADDLE_ENFORCE(ctx->HasOutput("Gate"),
                   "Output(%s) of GRUUnitOp should not be null.", "Gate");
    PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
                   "Output(%s) of GRUUnitOp should not be null.",
                   "ResetHiddenPrev");
    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                   "Output(%s) of GRUUnitOp should not be null.", "Hidden");
    auto input_dims = ctx->GetInputDim("Input");
    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
    auto weight_dims = ctx->GetInputDim("Weight");
    int batch_size = input_dims[0];
    int input_size = input_dims[1];
    int frame_size = hidden_prev_dims[1];
    int weight_height = weight_dims[0];
    int weight_width = weight_dims[1];
    PADDLE_ENFORCE_EQ(
        input_size, frame_size * 3,
        "The input_size must be 3 times frame_size in GRUUnitOp.");
    PADDLE_ENFORCE_EQ(
        weight_height, frame_size,
        "The shape of the Weight matrix must be [frame_size, frame_size * 3].");
    PADDLE_ENFORCE_EQ(
        weight_width, frame_size * 3,
        "The shape of the Weight matrix must be [frame_size, frame_size * 3].");
    auto bias = Input("Bias");
    if (bias != framework::kEmptyVarName) {
      auto bias_dims = ctx->GetInputDim("Bias");
      int bias_height = bias_dims[0];
      int bias_width = bias_dims[1];
      PADDLE_ENFORCE_EQ(bias_height, 1,
                        "The shape of Bias must be [1, frame_size * 3].");
      PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
                        "The shape of Bias must be [1, frame_size * 3].");
    }
    ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
    ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
    ctx->SetOutputDim("Hidden", {batch_size, frame_size});
  }
};
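For reference, the shape relationships enforced by InferShape above can be summarized in a small NumPy sketch; the variable names and sizes below are illustrative, not part of the operator.

import numpy as np

B, F = 3, 5  # illustrative batch_size and frame_size
x = np.empty((B, 3 * F), dtype='float32')    # Input: [batch_size, frame_size * 3]
h_p = np.empty((B, F), dtype='float32')      # HiddenPrev: [batch_size, frame_size]
w = np.empty((F, 3 * F), dtype='float32')    # Weight: [frame_size, frame_size * 3]
b = np.empty((1, 3 * F), dtype='float32')    # Bias (optional): [1, frame_size * 3]
# Output shapes set by the op:
#   Gate:            (B, 3 * F)
#   ResetHiddenPrev: (B, F)
#   Hidden:          (B, F)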

class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  GRUUnitOpMaker(framework::OpProto* proto,
                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
             "input.");
    AddInput("HiddenPrev",
             "(Tensor) Matrix with shape [batch_size, frame_size] for the "
             "states of the previous time step.");
    AddInput("Weight",
             "(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. "
             "The elements contiguous in memory can be divided into two parts: "
             "the first part is the weights of the update gate and reset gate "
             "with shape [frame_size, frame_size * 2], and the second part is "
             "the weights of the output candidate with shape "
             "[frame_size, frame_size].");
    AddInput("Bias",
             "(Tensor) Bias vector with shape [1, frame_size * 3] "
             "concatenating the bias of the update gate, reset gate and "
             "output candidate.");
    AddOutput("Gate",
              "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
              "output of the update gate, reset gate and output candidate.")
        .AsIntermediate();
    AddOutput("ResetHiddenPrev",
              "(Tensor) Matrix with shape [batch_size, frame_size] for the "
              "reset hidden state of the previous time step.")
        .AsIntermediate();
    AddOutput("Hidden",
              "(Tensor) The GRU hidden state of the current time step "
              "with shape [batch_size, frame_size].");
    AddAttr<int>("activation",
                 "(enum int, default tanh) "
                 "The activation type used for the output candidate {h}_t.")
        .SetDefault(tanh)
        .InEnum({identity, sigmoid, tanh, relu});
    AddAttr<int>("gate_activation",
                 "(enum int, default sigmoid) "
                 "The activation type used in the update gate and reset gate.")
        .SetDefault(sigmoid)
        .InEnum({identity, sigmoid, tanh, relu});
    AddComment(R"DOC(
GRUUnitOp implements part of the calculations of the GRU unit as follows:

\f[
update \ gate: u_t = actGate(xu_t + W_u * hidden_prev + bias_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * hidden_prev + bias_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, hidden_prev) + bias_c) \\
output: h_t = dot((1-u_t), {h}_t) + dot(u_t, hidden_prev)
\f]

The rest of the GRU unit can be completed by using FCOp's output as the input of GRUUnitOp.
)DOC");
  }
};
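As a readable companion to the equations in the DOC string, here is a minimal NumPy sketch of the forward computation. It mirrors the reference computation in the Python unit test further down in this diff; the function and variable names are illustrative, and x is assumed to be the already FC-projected input of shape [batch_size, frame_size * 3].

import numpy as np

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

def gru_unit_forward(x, h_p, w, b, act=np.tanh, gate_act=sigmoid):
    # x: [B, 3F], h_p: [B, F], w: [F, 3F], b: [1, 3F]
    B, F = h_p.shape
    g = x + np.tile(b, (B, 1))
    # In memory, the first 2*F*F elements of Weight hold the update/reset gate
    # weights as a [F, 2F] matrix; the remaining F*F elements hold the output
    # candidate weights as a [F, F] matrix.
    w_u_r = w.flatten()[:F * F * 2].reshape((F, F * 2))
    w_c = w.flatten()[F * F * 2:].reshape((F, F))
    u_r = gate_act(np.dot(h_p, w_u_r) + g[:, :F * 2])  # update and reset gates
    u, r = u_r[:, :F], u_r[:, F:]
    r_h_p = r * h_p                                     # reset previous hidden state
    c = act(np.dot(r_h_p, w_c) + g[:, F * 2:])          # output candidate
    h = u * h_p + (1. - u) * c                          # final hidden state
    return np.hstack((u_r, c)), r_h_p, h                # Gate, ResetHiddenPrev, Hidden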

class GRUUnitGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(%s) of GRUUnitGradOp should not be null.", "Input");
    PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
                   "Input(%s) of GRUUnitGradOp should not be null.",
                   "HiddenPrev");
    PADDLE_ENFORCE(ctx->HasInput("Weight"),
                   "Input(%s) of GRUUnitGradOp should not be null.", "Weight");
    PADDLE_ENFORCE(ctx->HasInput("Gate"),
                   "Input(%s) of GRUUnitGradOp should not be null.", "Gate");
    PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
                   "Input(%s) of GRUUnitGradOp should not be null.",
                   "ResetHiddenPrev");
    PADDLE_ENFORCE(ctx->HasInput("Hidden"),
                   "Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
                   "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
                   "Gate");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
                   "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
                   "ResetHiddenPrev");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
                   "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
                   "Hidden");
    auto input_dims = ctx->GetInputDim("Input");
    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
    auto weight_dims = ctx->GetInputDim("Weight");
    // int batch_size = input_dims[0];
    int input_size = input_dims[1];
    int frame_size = hidden_prev_dims[1];
    int weight_height = weight_dims[0];
    int weight_width = weight_dims[1];
    PADDLE_ENFORCE_EQ(
        input_size, frame_size * 3,
        "The input_size must be 3 times frame_size in GRUUnitGradOp.");
    PADDLE_ENFORCE_EQ(
        weight_height, frame_size,
        "The shape of the Weight matrix must be [frame_size, frame_size * 3].");
    PADDLE_ENFORCE_EQ(
        weight_width, frame_size * 3,
        "The shape of the Weight matrix must be [frame_size, frame_size * 3].");
    auto bias = Input("Bias");
    if (bias != framework::kEmptyVarName) {
      auto bias_dims = ctx->GetInputDim("Bias");
      int bias_height = bias_dims[0];
      int bias_width = bias_dims[1];
      PADDLE_ENFORCE_EQ(bias_height, 1,
                        "The shape of Bias must be [1, frame_size * 3].");
      PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
                        "The shape of Bias must be [1, frame_size * 3].");
      auto bias_grad_name = framework::GradVarName("Bias");
      if (ctx->HasOutput(bias_grad_name))
        ctx->SetOutputDim(bias_grad_name, bias_dims);
    }
    auto input_grad_name = framework::GradVarName("Input");
    if (ctx->HasOutput(input_grad_name))
      ctx->SetOutputDim(input_grad_name, input_dims);
    auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev");
    if (ctx->HasOutput(hidden_prev_grad_name))
      ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims);
    auto weight_grad_name = framework::GradVarName("Weight");
    if (ctx->HasOutput(weight_grad_name))
      ctx->SetOutputDim(weight_grad_name, weight_dims);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, gru_unit_grad,
            ops::GRUUnitGradOp);
REGISTER_OP_CPU_KERNEL(gru_unit,
                       ops::GRUUnitKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::CPUPlace, float>);
@@ -0,0 +1,22 @@ paddle/operators/gru_unit_op.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/gru_unit_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gru_unit,
                       ops::GRUUnitKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>);
@@ -0,0 +1,230 @@ paddle/operators/gru_unit_op.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/operators/activation_op.h"
#include "paddle/operators/math/math_function.h"

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 };

template <typename Place, typename T>
class GRUUnitKernel : public framework::OpKernel<T> {
 public:
  template <typename Device, typename X, typename Y>
  void ActCompute(const int act_type, const Device& d, X x, Y y) const {
    if (act_type == identity)
      y.device(d) = x;
    else if (act_type == sigmoid)
      SigmoidFunctor<T>()(d, x, y);
    else if (act_type == tanh)
      TanhFunctor<T>()(d, x, y);
    else if (act_type == relu)
      ReluFunctor<T>()(d, x, y);
    else
      PADDLE_THROW("unsupported activation type");
  }

  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("Input");
    auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
    auto* weight = context.Input<Tensor>("Weight");
    auto* bias = context.Input<Tensor>("Bias");
    auto* gate = context.Output<Tensor>("Gate");
    gate->mutable_data<T>(context.GetPlace());
    auto* reset_hidden_prev = context.Output<Tensor>("ResetHiddenPrev");
    reset_hidden_prev->mutable_data<T>(context.GetPlace());
    auto* hidden = context.Output<Tensor>("Hidden");
    hidden->mutable_data<T>(context.GetPlace());

    int batch_size = input->dims()[0];
    int frame_size = hidden_prev->dims()[1];

    auto x = EigenMatrix<T>::From(*input);
    auto h_p = EigenMatrix<T>::From(*hidden_prev);
    auto g = EigenMatrix<T>::From(*gate);
    auto r_h_p = EigenMatrix<T>::From(*reset_hidden_prev);
    auto h = EigenMatrix<T>::From(*hidden);
    auto place = context.GetEigenDevice<Place>();

    // calculate unactivated gate outputs
    if (bias) {
      auto b = EigenMatrix<T>::From(*bias);
      g.device(place) =
          x + b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
                  .broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
    } else {
      g.device(place) = x;
    }
    const T* hidden_prev_data = hidden_prev->data<T>();
    const T* weight_data = weight->data<T>();
    T* gate_data = gate->data<T>();
    T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
    math::gemm<Place, T>(context.device_context(), false, false, batch_size,
                         2 * frame_size, frame_size, 1, hidden_prev_data,
                         frame_size, weight_data, frame_size * 2, 1, gate_data,
                         frame_size * 3);

    // calculate activated gates
    Eigen::array<int, 2> extents({{batch_size, frame_size}});
    Eigen::array<int, 2> u_offsets({{0, 0}});
    ActCompute(context.Attr<int>("gate_activation"), place,
               g.slice(u_offsets, extents), g.slice(u_offsets, extents));
    auto u = g.slice(u_offsets, extents);  // update gate
    Eigen::array<int, 2> r_offsets({{0, frame_size}});
    ActCompute(context.Attr<int>("gate_activation"), place,
               g.slice(r_offsets, extents), g.slice(r_offsets, extents));
    auto r = g.slice(r_offsets, extents);  // reset gate
    r_h_p.device(place) = r * h_p;         // reset previous hidden state
    math::gemm<Place, T>(context.device_context(), false, false, batch_size,
                         frame_size, frame_size, 1, reset_hidden_prev_data,
                         frame_size, weight_data + frame_size * frame_size * 2,
                         frame_size, 1, gate_data + frame_size * 2,
                         frame_size * 3);

    Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
    ActCompute(context.Attr<int>("activation"), place,
               g.slice(c_offsets, extents), g.slice(c_offsets, extents));
    auto c = g.slice(c_offsets, extents);  // output candidate

    // calculate final output
    h.device(place) = u * (h_p - c) + c;
  }
};
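The gemm calls above accumulate into column blocks of the Gate matrix while treating the Weight tensor as two contiguous sub-matrices in memory. As a sanity check on the gemm parameters, a rough NumPy equivalent of the first call is sketched below, assuming row-major storage; the variable names are illustrative.

import numpy as np

B, F = 3, 5
h_p = np.random.rand(B, F).astype('float32')     # HiddenPrev
w = np.random.rand(F, 3 * F).astype('float32')   # Weight
g = np.random.rand(B, 3 * F).astype('float32')   # Gate (already holds Input + Bias)

# gemm(false, false, M=B, N=2F, K=F, alpha=1, A=hidden_prev_data, lda=F,
#      B=weight_data, ldb=2F, beta=1, C=gate_data, ldc=3F) computes
#   Gate[:, :2F] += HiddenPrev . W_gates
# where W_gates is the first 2*F*F elements of Weight viewed as a [F, 2F] matrix.
w_gates = w.flatten()[:2 * F * F].reshape(F, 2 * F)
g[:, :2 * F] += h_p.dot(w_gates)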

template <typename Place, typename T>
class GRUUnitGradKernel : public framework::OpKernel<T> {
 public:
  template <typename Device, typename X, typename Y, typename DX, typename DY>
  void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx,
                      DY dy) const {
    // x is a dummy argument and is not used; even ReLU uses y instead.
    if (act_type == identity)
      dx.device(d) = dy;
    else if (act_type == sigmoid)
      SigmoidGradFunctor<T>()(d, x, y, dy, dx);
    else if (act_type == tanh)
      TanhGradFunctor<T>()(d, x, y, dy, dx);
    else if (act_type == relu)
      ReluGradFunctor<T>()(d, x, y, dy, dx);
    else
      PADDLE_THROW("unsupported activation type");
  }

  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("Input");
    auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
    auto* weight = context.Input<Tensor>("Weight");
    auto* gate = context.Input<Tensor>("Gate");
    auto* reset_hidden_prev = context.Input<Tensor>("ResetHiddenPrev");
    auto* hidden_grad = context.Input<Tensor>(framework::GradVarName("Hidden"));
    auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
    auto* hidden_prev_grad =
        context.Output<Tensor>(framework::GradVarName("HiddenPrev"));
    auto* weight_grad =
        context.Output<Tensor>(framework::GradVarName("Weight"));
    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
    input_grad->mutable_data<T>(context.GetPlace());
    hidden_prev_grad->mutable_data<T>(context.GetPlace());
    weight_grad->mutable_data<T>(context.GetPlace());
    Tensor gate_grad;
    gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
    Tensor reset_hidden_prev_grad;
    reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
                                           context.GetPlace());

    int batch_size = input->dims()[0];
    int frame_size = hidden_prev->dims()[1];

    const T* hidden_prev_data = hidden_prev->data<T>();
    T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
    const T* weight_data = weight->data<T>();
    T* weight_grad_data = weight_grad->data<T>();
    T* gate_grad_data = gate_grad.data<T>();
    const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
    T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>();

    auto h_p = EigenMatrix<T>::From(*hidden_prev);
    auto g = EigenMatrix<T>::From(*gate);
    auto d_h = EigenMatrix<T>::From(*hidden_grad);
    auto d_x = EigenMatrix<T>::From(*input_grad);
    auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
    auto d_g = EigenMatrix<T>::From(gate_grad);
    auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
    auto place = context.GetEigenDevice<Place>();

    Eigen::array<int, 2> extents({{batch_size, frame_size}});
    Eigen::array<int, 2> u_offsets({{0, 0}});
    auto u = g.slice(u_offsets, extents);  // update gate
    Eigen::array<int, 2> r_offsets({{0, frame_size}});
    auto r = g.slice(r_offsets, extents);  // reset gate
    Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
    auto c = g.slice(c_offsets, extents);  // output candidate

    // backward for unactivated update gate
    ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
                   d_g.slice(u_offsets, extents), d_h * (h_p - c));
    // backward for unactivated output candidate
    ActGradCompute(context.Attr<int>("activation"), place, c, c,
                   d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
    // backward for reset_hidden_prev
    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
                         frame_size, frame_size, 1,
                         gate_grad_data + frame_size * 2, frame_size * 3,
                         weight_data + frame_size * frame_size * 2, frame_size,
                         0, reset_hidden_prev_grad_data, frame_size);
    // backward for state_weight
    math::gemm<Place, T>(
        context.device_context(), true, false, frame_size, frame_size,
        batch_size, 1, reset_hidden_prev_data, frame_size,
        gate_grad_data + frame_size * 2, frame_size * 3, 0,
        weight_grad_data + frame_size * frame_size * 2, frame_size);
    // backward for unactivated reset gate
    ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
                   d_g.slice(r_offsets, extents), d_r_h_p * h_p);
    // backward for update_gate_weight and reset_gate_weight
    math::gemm<Place, T>(context.device_context(), true, false, frame_size,
                         frame_size * 2, batch_size, 1, hidden_prev_data,
                         frame_size, gate_grad_data, frame_size * 3, 0,
                         weight_grad_data, frame_size * 2);
    // backward for hidden_prev
    d_h_p.device(place) = d_r_h_p * r + d_h * u;
    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
                         frame_size, frame_size * 2, 1, gate_grad_data,
                         frame_size * 3, weight_data, frame_size * 2, 1,
                         hidden_prev_grad_data, frame_size);
    // backward for input
    d_x.device(place) = d_g;
    // backward for bias
    if (bias_grad) {
      bias_grad->mutable_data<T>(context.GetPlace());
      auto d_b = EigenMatrix<T>::From(*bias_grad);
      d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
    }
  }
};
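For readability, here is a minimal NumPy sketch of the gradient formulas that the kernel above implements, assuming gate_activation = sigmoid and activation = tanh (so the activation gradients can be written in terms of the activated outputs); the function and variable names are illustrative.

import numpy as np

def gru_unit_backward(d_h, h_p, u, r, c, r_h_p, w):
    # d_h: gradient of Hidden, [B, F]; u, r, c: activated Gate slices, [B, F];
    # r_h_p: ResetHiddenPrev, [B, F]; w: Weight, [F, 3F]
    F = h_p.shape[1]
    w_u_r = w.flatten()[:2 * F * F].reshape(F, 2 * F)
    w_c = w.flatten()[2 * F * F:].reshape(F, F)
    # backward through the unactivated update gate and output candidate
    d_u = d_h * (h_p - c) * u * (1. - u)    # sigmoid'(y) = y * (1 - y)
    d_c = d_h * (1. - u) * (1. - c * c)     # tanh'(y) = 1 - y^2
    # backward for reset_hidden_prev and the unactivated reset gate
    d_r_h_p = d_c.dot(w_c.T)
    d_r = d_r_h_p * h_p * r * (1. - r)
    d_g = np.hstack((d_u, d_r, d_c))        # gradient w.r.t. Gate
    # weight, hidden_prev, input and bias gradients
    d_w_u_r = h_p.T.dot(d_g[:, :2 * F])
    d_w_c = r_h_p.T.dot(d_c)
    d_w = np.concatenate([d_w_u_r.ravel(), d_w_c.ravel()]).reshape(F, 3 * F)
    d_h_p = d_r_h_p * r + d_h * u + d_g[:, :2 * F].dot(w_u_r.T)
    d_x = d_g
    d_b = d_g.sum(axis=0, keepdims=True)
    return d_x, d_h_p, d_w, d_b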

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,115 @@ python/paddle/v2/framework/tests/test_gru_unit_op.py
import math
import unittest
import numpy as np
from op_test import OpTest


class GRUActivationType(object):
    identity = 0
    sigmoid = 1
    tanh = 2
    relu = 3


def identity(x):
    return x


def sigmoid(x):
    return 1. / (1. + np.exp(-x))


def tanh(x):
    return 2. * sigmoid(2. * x) - 1.


def relu(x):
    return np.maximum(x, 0)


class TestGRUUnitOp(OpTest):
    batch_size = 3
    frame_size = 5
    activate = {
        GRUActivationType.identity: identity,
        GRUActivationType.sigmoid: sigmoid,
        GRUActivationType.tanh: tanh,
        GRUActivationType.relu: relu,
    }

    def set_inputs(self):
        batch_size = self.batch_size
        frame_size = self.frame_size
        self.op_type = 'gru_unit'
        self.inputs = {
            'Input': np.random.uniform(
                -0.1, 0.1, (batch_size, frame_size * 3)).astype('float32'),
            'HiddenPrev': np.random.uniform(
                -0.1, 0.1, (batch_size, frame_size)).astype('float32'),
            'Weight': np.random.uniform(
                -1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
                (frame_size, frame_size * 3)).astype('float32'),
        }
        self.attrs = {
            'activation': GRUActivationType.tanh,
            'gate_activation': GRUActivationType.sigmoid
        }

    def set_outputs(self):
        # reference GRU calculations
        batch_size = self.batch_size
        frame_size = self.frame_size
        x = self.inputs['Input']
        h_p = self.inputs['HiddenPrev']
        w = self.inputs['Weight']
        b = self.inputs['Bias'] if 'Bias' in self.inputs else np.zeros(
            (1, frame_size * 3))
        g = x + np.tile(b, (batch_size, 1))
        w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
            (frame_size, frame_size * 2))
        u_r = self.activate[self.attrs['gate_activation']](np.dot(
            h_p, w_u_r) + g[:, :frame_size * 2])
        u = u_r[:, :frame_size]
        r = u_r[:, frame_size:frame_size * 2]
        r_h_p = r * h_p
        w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
            (frame_size, frame_size))
        c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
                                                    g[:, frame_size * 2:])
        g = np.hstack((u_r, c))
        h = u * h_p + (1 - u) * c
        self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h}

    def setUp(self):
        self.set_inputs()
        self.set_outputs()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            ['Input', 'HiddenPrev', 'Weight'], ['Hidden'],
            max_relative_error=0.007)


class TestGRUUnitOpWithBias(TestGRUUnitOp):
    def set_inputs(self):
        batch_size = self.batch_size
        frame_size = self.frame_size
        super(TestGRUUnitOpWithBias, self).set_inputs()
        self.inputs['Bias'] = np.random.uniform(
            -0.1, 0.1, (1, frame_size * 3)).astype('float32')
        self.attrs = {
            'activation': GRUActivationType.identity,
            'gate_activation': GRUActivationType.sigmoid
        }

    def test_check_grad(self):
        self.check_grad(
            ['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
            max_relative_error=0.007)


if __name__ == '__main__':
    unittest.main()