|
|
|
@ -25,19 +25,14 @@ class GRUUnitOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(%s) of GRUUnitOp should not be null.", "Input");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
|
|
|
|
|
"Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"),
|
|
|
|
|
"Input(%s) of GRUUnitOp should not be null.", "Weight");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Gate"),
|
|
|
|
|
"Output(%s) of GRUUnitOp should not be null.", "Gate");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
|
|
|
|
|
"Output(%s) of GRUUnitOp should not be null.",
|
|
|
|
|
"ResetHiddenPrev");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
|
|
|
|
|
"Output(%s) of GRUUnitOp should not be null.", "Hidden");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev",
|
|
|
|
|
"GRUUnit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Gate"), "Output", "Gate", "GRUUnit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("ResetHiddenPrev"), "Output",
|
|
|
|
|
"ResetHiddenPrev", "GRUUnit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRUUnit");
|
|
|
|
|
auto input_dims = ctx->GetInputDim("Input");
|
|
|
|
|
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
|
|
|
|
|
auto weight_dims = ctx->GetInputDim("Weight");
|
|
|
|
@ -46,23 +41,45 @@ class GRUUnitOp : public framework::OperatorWithKernel {
|
|
|
|
|
int frame_size = hidden_prev_dims[1];
|
|
|
|
|
int weight_height = weight_dims[0];
|
|
|
|
|
int weight_width = weight_dims[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_size, frame_size * 3,
|
|
|
|
|
"The input_size must be 3 times of frame_size in GRUUnitOp.");
|
|
|
|
|
if (ctx->IsRuntime() || input_size >= 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(Input) must be 3 "
|
|
|
|
|
"times of frame_size in GRUUnitOp, but received %d "
|
|
|
|
|
"(Input) vs %d (frame_size).",
|
|
|
|
|
input_size, frame_size));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_height, frame_size,
|
|
|
|
|
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
|
|
|
|
|
"* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] "
|
|
|
|
|
"(frame_size).",
|
|
|
|
|
weight_height, weight_width, frame_size, frame_size * 3));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_width, frame_size * 3,
|
|
|
|
|
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
|
|
|
|
|
"* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] "
|
|
|
|
|
"(frame_size).",
|
|
|
|
|
weight_height, weight_width, frame_size, frame_size * 3));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("Bias")) {
|
|
|
|
|
auto bias_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
int bias_height = bias_dims[0];
|
|
|
|
|
int bias_width = bias_dims[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_height, 1,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias_height, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but received "
|
|
|
|
|
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
|
|
|
|
|
bias_height, bias_width, frame_size * 3));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias_width, frame_size * 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but received "
|
|
|
|
|
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
|
|
|
|
|
bias_height, bias_width, frame_size * 3));
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
|
|
|
|
|
ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
|
|
|
|
@ -143,21 +160,16 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(%s) of GRUUnitGradOp should not be null.", "Input");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
|
|
|
|
|
"Input(%s) of GRUUnitGradOp should not be null.",
|
|
|
|
|
"HiddenPrev");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"),
|
|
|
|
|
"Input(%s) of GRUUnitGradOp should not be null.", "Weight");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Gate"),
|
|
|
|
|
"Input(%s) of GRUUnitGradOp should not be null.", "Gate");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
|
|
|
|
|
"Input(%s) of GRUUnitGradOp should not be null.",
|
|
|
|
|
"ResetHiddenPrev");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
|
|
|
|
|
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
|
|
|
|
|
"Hidden");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnitGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev",
|
|
|
|
|
"GRUUnitGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnitGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Gate"), "Input", "Gate", "GRUUnitGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("ResetHiddenPrev"), "Input", "ResetHiddenPrev",
|
|
|
|
|
"GRUUnitGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), "Input",
|
|
|
|
|
"Hidden@GRAD", "GRUUnitGrad");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("Input");
|
|
|
|
|
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
|
|
|
|
|
auto weight_dims = ctx->GetInputDim("Weight");
|
|
|
|
@ -166,23 +178,46 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
int frame_size = hidden_prev_dims[1];
|
|
|
|
|
int weight_height = weight_dims[0];
|
|
|
|
|
int weight_width = weight_dims[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_size, frame_size * 3,
|
|
|
|
|
"The input_size must be 3 times of frame_size in GRUUnitOp.");
|
|
|
|
|
if (ctx->IsRuntime() || input_size >= 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_size, frame_size * 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(Input) must be 3 "
|
|
|
|
|
"times of frame_size in GRUUnitGradOp, but received %d "
|
|
|
|
|
"(Input) vs %d (frame_size).",
|
|
|
|
|
input_size, frame_size));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_height, frame_size,
|
|
|
|
|
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
|
|
|
|
|
"* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] "
|
|
|
|
|
"(frame_size).",
|
|
|
|
|
weight_height, weight_width, frame_size, frame_size * 3));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_width, frame_size * 3,
|
|
|
|
|
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
|
|
|
|
|
"* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] "
|
|
|
|
|
"(frame_size).",
|
|
|
|
|
weight_height, weight_width, frame_size, frame_size * 3));
|
|
|
|
|
if (ctx->HasInput("Bias")) {
|
|
|
|
|
auto bias_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
int bias_height = bias_dims[0];
|
|
|
|
|
int bias_width = bias_dims[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_height, 1,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias_height, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but received "
|
|
|
|
|
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
|
|
|
|
|
bias_height, bias_width, frame_size * 3));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
bias_width, frame_size * 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but received "
|
|
|
|
|
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
|
|
|
|
|
bias_height, bias_width, frame_size * 3));
|
|
|
|
|
auto bias_grad_name = framework::GradVarName("Bias");
|
|
|
|
|
if (ctx->HasOutput(bias_grad_name))
|
|
|
|
|
ctx->SetOutputDim(bias_grad_name, bias_dims);
|
|
|
|
|