|
|
|
@ -24,51 +24,80 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
|
|
|
|
|
"Assert only one Input(WeightX) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
|
|
|
|
|
"Assert only one Input(WeightH) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
|
|
|
|
|
"Assert only one Output(Hidden) of GRU.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X)'s rank must be 2, but received input dim "
|
|
|
|
|
"size is:%d, input dim is:[%s]",
|
|
|
|
|
x_dims.size(), x_dims));
|
|
|
|
|
|
|
|
|
|
auto wx_dims = ctx->GetInputDim("WeightX");
|
|
|
|
|
PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
|
|
|
|
|
"The rank of Input(WeightX) should be 2.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(WeightX) should be 2, but received "
|
|
|
|
|
"WeightX dim size is:%d, WeightX dim is:[%s] ",
|
|
|
|
|
wx_dims.size(), wx_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(WeightX) "
|
|
|
|
|
"should be %d.",
|
|
|
|
|
x_dims[1]);
|
|
|
|
|
"should equal to second dimension of input x, but "
|
|
|
|
|
"received WeightX dimension is:%d, x dimension is:%d",
|
|
|
|
|
wx_dims[0], x_dims[1]));
|
|
|
|
|
|
|
|
|
|
int frame_size = wx_dims[1] / 3;
|
|
|
|
|
auto wh_dims = ctx->GetInputDim("WeightH");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
|
|
|
|
|
"The rank of Input(WeightH) should be 2.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(WeightH) should be 2, but received "
|
|
|
|
|
"WeightH dim size is:%d, WeightH dim is:[%s]",
|
|
|
|
|
wh_dims.size(), wh_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
|
|
|
|
|
"The first dimension of Input(WeightH) "
|
|
|
|
|
"should be %d.",
|
|
|
|
|
frame_size);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of WeightH "
|
|
|
|
|
"should equal to frame_size, but received WeightH's "
|
|
|
|
|
"first dimension is: "
|
|
|
|
|
"%d, frame size is:%d",
|
|
|
|
|
wh_dims[0], frame_size));
|
|
|
|
|
PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(WeightH) "
|
|
|
|
|
"should be 3 * %d.",
|
|
|
|
|
frame_size);
|
|
|
|
|
"should equal to 3 * frame_size, but received WeightH "
|
|
|
|
|
"is:%d, frame size is:%d",
|
|
|
|
|
wh_dims[1], frame_size));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("H0")) {
|
|
|
|
|
auto h0_dims = ctx->GetInputDim("H0");
|
|
|
|
|
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
|
|
|
|
|
"The width of H0 must be equal to frame_size.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The width of H0 must be equal to frame_size, but "
|
|
|
|
|
"receiced the width of H0 is:%d, frame size is:%d",
|
|
|
|
|
h0_dims[1], frame_size));
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasInput("Bias")) {
|
|
|
|
|
auto b_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(Bias) should be 2, but received "
|
|
|
|
|
"Bias rank is:%d, Bias dim is:[%s]",
|
|
|
|
|
b_dims.size(), b_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], 1,
|
|
|
|
|
"The first dimension of Input(Bias) should be 1.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(Bias) should be 1, but "
|
|
|
|
|
"received Bias first dim is:%d, Bias dim is:[%s]",
|
|
|
|
|
b_dims[0], b_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Bias must be [1, frame_size * 3], but "
|
|
|
|
|
"received bias dim is:[%s], frame size is:%d",
|
|
|
|
|
b_dims, frame_size));
|
|
|
|
|
}
|
|
|
|
|
framework::DDim out_dims({x_dims[0], frame_size});
|
|
|
|
|
ctx->SetOutputDim("Hidden", out_dims);
|
|
|
|
@ -78,12 +107,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
xx_width = wx_dims[1];
|
|
|
|
|
} else {
|
|
|
|
|
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
|
|
|
|
|
"Assert only one Output(ReorderedH0) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
|
|
|
|
|
"Assert only one Output(BatchedInput) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
|
|
|
|
|
"Assert only one Output(BatchedOut) of GRU.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
|
|
|
|
|
"fusion_gru");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
|
|
|
|
|
"fusion_gru");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut",
|
|
|
|
|
"fusion_gru");
|
|
|
|
|
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
|
|
|
|
|
ctx->SetOutputDim("BatchedOut", out_dims);
|
|
|
|
|
}
|
|
|
|
|