|
|
|
@ -31,44 +31,58 @@ class GRUOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(%s) of GRUOp should not be null.", "Input");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"),
|
|
|
|
|
"Input(%s) of GRUOp should not be null.", "Weight");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
|
|
|
|
|
"Output(%s) of GRUOp should not be null.", "BatchGate");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
|
|
|
|
|
"Output(%s) of GRUOp should not be null.",
|
|
|
|
|
"BatchResetHiddenPrev");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
|
|
|
|
|
"Output(%s) of GRUOp should not be null.", "BatchHidden");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
|
|
|
|
|
"Output(%s) of GRUOp should not be null.", "Hidden");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
|
|
|
|
|
"BatchResetHiddenPrev", "GRU");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
|
|
|
|
|
"GRU");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("Input");
|
|
|
|
|
auto weight_dims = ctx->GetInputDim("Weight");
|
|
|
|
|
int input_size = input_dims[1];
|
|
|
|
|
int frame_size = weight_dims[0];
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_size, frame_size * 3,
|
|
|
|
|
"The input_size must be 3 times of frame_size in GRUOp.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(Input) must be 3 "
|
|
|
|
|
"times of frame_size in GRUOp, but received %d "
|
|
|
|
|
"(Input) vs %d (frame_size).",
|
|
|
|
|
input_size, frame_size));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_dims[1], frame_size * 3,
|
|
|
|
|
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
|
|
|
|
|
"* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
|
|
|
|
|
weight_dims[0], weight_dims[1], frame_size, frame_size * 3));
|
|
|
|
|
if (ctx->HasInput("H0")) {
|
|
|
|
|
auto h0_dims = ctx->GetInputDim("H0");
|
|
|
|
|
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
|
|
|
|
|
"The width of H0 must be equal to frame_size.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
h0_dims[1], frame_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The width of Input(H0) must be equal to frame_size, but "
|
|
|
|
|
"received %d (width of H0) vs %d (frame_size).",
|
|
|
|
|
h0_dims[1], frame_size));
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasInput("Bias")) {
|
|
|
|
|
auto bias_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
int bias_height = bias_dims[0];
|
|
|
|
|
int bias_width = bias_dims[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_height, 1,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias_height, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but received "
|
|
|
|
|
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
|
|
|
|
|
bias_height, bias_width, frame_size * 3));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias_width, frame_size * 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but received "
|
|
|
|
|
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
|
|
|
|
|
bias_height, bias_width, frame_size * 3));
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("BatchGate", input_dims);
|
|
|
|
|
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
|
|
|
|
@ -166,39 +180,50 @@ class GRUGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(%s) of GRUGradOp should not be null.", "Input");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"),
|
|
|
|
|
"Input(%s) of GRUGradOp should not be null.", "Weight");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
|
|
|
|
|
"Input(%s) of GRUGradOp should not be null.", "BatchGate");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("BatchResetHiddenPrev"),
|
|
|
|
|
"Input(%s) of GRUGradOp should not be null.",
|
|
|
|
|
"BatchResetHiddenPrev");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("BatchHidden"),
|
|
|
|
|
"Input(%s) of GRUOp should not be null.", "BatchHidden");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
|
|
|
|
|
"Input(%s) of GRUGradOp should not be null.", "Hidden");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
|
|
|
|
|
"Input(%s@GRAD) of GRUGradOp should not be null.", "Hidden");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
|
|
|
|
|
"GRU@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("BatchResetHiddenPrev"), "Input",
|
|
|
|
|
"BatchResetHiddenPrev", "GRU@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("BatchHidden"), "Input", "BatchHidden",
|
|
|
|
|
"GRU@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "GRU@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), "Input",
|
|
|
|
|
framework::GradVarName("Hidden"), "GRU@Grad");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("Input");
|
|
|
|
|
auto weight_dims = ctx->GetInputDim("Weight");
|
|
|
|
|
int input_size = input_dims[1];
|
|
|
|
|
int frame_size = weight_dims[0];
|
|
|
|
|
int weight_height = weight_dims[0];
|
|
|
|
|
int weight_width = weight_dims[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
|
|
|
|
|
"The input_size must be 3 times of frame_size in GRUOp.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_size, frame_size * 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(Input) must be 3 times of "
|
|
|
|
|
"frame_size in GRUOp, but received %d (Input) vs %d (frame_size).",
|
|
|
|
|
input_size, frame_size));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_height, frame_size,
|
|
|
|
|
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
|
|
|
|
|
"* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
|
|
|
|
|
weight_height, weight_width, frame_size, frame_size * 3));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_width, frame_size * 3,
|
|
|
|
|
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
|
|
|
|
|
"* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
|
|
|
|
|
weight_height, weight_width, frame_size, frame_size * 3));
|
|
|
|
|
if (ctx->HasInput("H0")) {
|
|
|
|
|
auto h0_dims = ctx->GetInputDim("H0");
|
|
|
|
|
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
|
|
|
|
|
"The width of H0 must be equal to frame_size.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
h0_dims[1], frame_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The width of Input(H0) must be equal to frame_size, but "
|
|
|
|
|
"received %d (width of H0) vs %d (frame_size).",
|
|
|
|
|
h0_dims[1], frame_size));
|
|
|
|
|
auto h0_grad_name = framework::GradVarName("H0");
|
|
|
|
|
if (ctx->HasOutput(h0_grad_name))
|
|
|
|
|
ctx->SetOutputDim(h0_grad_name, h0_dims);
|
|
|
|
@ -207,10 +232,18 @@ class GRUGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto bias_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
int bias_height = bias_dims[0];
|
|
|
|
|
int bias_width = bias_dims[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_height, 1,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias_height, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but received "
|
|
|
|
|
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
|
|
|
|
|
bias_height, bias_width, frame_size * 3));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias_width, frame_size * 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but received "
|
|
|
|
|
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
|
|
|
|
|
bias_height, bias_width, frame_size * 3));
|
|
|
|
|
auto bias_grad_name = framework::GradVarName("Bias");
|
|
|
|
|
if (ctx->HasOutput(bias_grad_name))
|
|
|
|
|
ctx->SetOutputDim(bias_grad_name, bias_dims);
|
|
|
|
@ -298,14 +331,20 @@ class GRUCPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
|
|
|
|
|
frame_size * 2 /*width of weight*/,
|
|
|
|
|
frame_size /*height of height*/);
|
|
|
|
|
PADDLE_ENFORCE(packed_gate);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
packed_gate, platform::errors::NotFound(
|
|
|
|
|
"The caculation result of packed_gate by "
|
|
|
|
|
"GEMM_ALLOC should not be null when using MKL."));
|
|
|
|
|
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
|
|
|
|
|
frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
|
|
|
|
|
packed_gate);
|
|
|
|
|
T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
|
|
|
|
|
frame_size /*width of weight*/,
|
|
|
|
|
frame_size /*height of height*/);
|
|
|
|
|
PADDLE_ENFORCE(packed_state);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
packed_state, platform::errors::NotFound(
|
|
|
|
|
"The caculation result of packed_state by "
|
|
|
|
|
"GEMM_ALLOC should not be null when using MKL."));
|
|
|
|
|
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
|
|
|
|
|
frame_size, T(1.0), gru_value.state_weight, frame_size,
|
|
|
|
|
packed_state);
|
|
|
|
|