|
|
|
@ -15,6 +15,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/fusion_gru_op.h"
|
|
|
|
|
#include <cstring> // for memcpy
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/framework/shape_runtime_infer.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/cpu_vec.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/fc_compute.h"
|
|
|
|
@ -25,14 +26,46 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
|
|
|
|
|
"Input(WeightX) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
|
|
|
|
|
"Input(WeightH) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
|
|
|
|
|
"Output(Hidden) of GRU should not be null.");
|
|
|
|
|
auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
|
|
|
|
|
if (runtime_ctx == nullptr) {
|
|
|
|
|
LOG(FATAL) << "Should have runtime infer context";
|
|
|
|
|
}
|
|
|
|
|
const auto& ins = runtime_ctx->OpBase().Inputs();
|
|
|
|
|
const auto& outs = runtime_ctx->OpBase().Outputs();
|
|
|
|
|
const auto& scope = runtime_ctx->InferScope();
|
|
|
|
|
const auto ins_end = ins.end();
|
|
|
|
|
const auto outs_end = outs.end();
|
|
|
|
|
auto fair_input = [&](const std::string& name) -> bool {
|
|
|
|
|
auto it = ins.find(name);
|
|
|
|
|
if (it == ins_end) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const auto& in = it->second;
|
|
|
|
|
if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return scope.FindVar(in[0]) != nullptr;
|
|
|
|
|
};
|
|
|
|
|
auto fair_output = [&](const std::string& name) -> bool {
|
|
|
|
|
auto it = outs.find(name);
|
|
|
|
|
if (it == outs_end) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const auto& out = it->second;
|
|
|
|
|
if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return scope.FindVar(out[0]) != nullptr;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(fair_input("WeightX"),
|
|
|
|
|
"Assert only one Input(WeightX) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(fair_input("WeightH"),
|
|
|
|
|
"Assert only one Input(WeightH) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(fair_output("Hidden"),
|
|
|
|
|
"Assert only one Output(Hidden) of GRU.");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
@ -58,12 +91,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
"should be 3 * %d.",
|
|
|
|
|
frame_size);
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("H0")) {
|
|
|
|
|
if (fair_input("H0")) {
|
|
|
|
|
auto h0_dims = ctx->GetInputDim("H0");
|
|
|
|
|
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
|
|
|
|
|
"The width of H0 must be equal to frame_size.");
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasInput("Bias")) {
|
|
|
|
|
if (fair_input("Bias")) {
|
|
|
|
|
auto b_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], 1,
|
|
|
|
@ -79,12 +112,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
xx_width = wx_dims[1];
|
|
|
|
|
} else {
|
|
|
|
|
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
|
|
|
|
|
"Output(ReorderedH0) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
|
|
|
|
|
"Output(BatchedInput) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
|
|
|
|
|
"Output(BatchedOut) of GRU should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(fair_output("ReorderedH0"),
|
|
|
|
|
"Assert only one Output(ReorderedH0) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(fair_output("BatchedInput"),
|
|
|
|
|
"Assert only one Output(BatchedInput) of GRU.");
|
|
|
|
|
PADDLE_ENFORCE(fair_output("BatchedOut"),
|
|
|
|
|
"Assert only one Output(BatchedOut) of GRU.");
|
|
|
|
|
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
|
|
|
|
|
ctx->SetOutputDim("BatchedOut", out_dims);
|
|
|
|
|
}
|
|
|
|
|