@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/attention_lstm_op.h"
 #include <string>
+#include "paddle/fluid/framework/shape_runtime_infer.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
@@ -23,29 +24,60 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Input(X) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("C0"),
-                 "Input(C0) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
-                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
-                 "Input(LSTMBias) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
-                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
-                 "Output(AttentionedX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
-                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
-                 "Output(LSTMX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
-                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
+  auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
+  if (runtime_ctx == nullptr) {
+    LOG(FATAL) << "Should have runtime infer context";
+  }
+  const auto& ins = runtime_ctx->OpBase().Inputs();
+  const auto& outs = runtime_ctx->OpBase().Outputs();
+  const auto& scope = runtime_ctx->InferScope();
+  const auto ins_end = ins.end();
+  const auto outs_end = outs.end();
+  auto fair_input = [&](const std::string& name) -> bool {
+    auto it = ins.find(name);
+    if (it == ins_end) {
+      return false;
+    }
+    const auto& in = it->second;
+    if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(in[0]) != nullptr;
+  };
+  auto fair_output = [&](const std::string& name) -> bool {
+    auto it = outs.find(name);
+    if (it == outs_end) {
+      return false;
+    }
+    const auto& out = it->second;
+    if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(out[0]) != nullptr;
+  };
+
+  PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("C0"),
+                 "Assert only one Input(C0) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMWeight"),
+                 "Assert only one Input(LSTMWeight) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMBias"),
+                 "Assert only one Input(LSTMBias) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("AttentionWeight"),
+                 "Assert only one Input(AttentionWeight) of AttentionLSTM.");
+
+  PADDLE_ENFORCE(fair_output("Hidden"),
+                 "Assert only one Output(Hidden) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("Cell"),
+                 "Assert only one Output(Cell) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("AttentionedX"),
+                 "Assert only one Output(AttentionedX) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("AttentionFCOut"),
+                 "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("LSTMX"),
+                 "Assert only one Output(LSTMX) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("LSTMOUT"),
+                 "Assert only one Output(LSTMOUT) of AttentionLSTM.");
 
   auto x_dims = ctx->GetInputDim("X");
   const int M = x_dims[1];
@@ -65,7 +97,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
   PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
-  if (ctx->HasInput("H0")) {
+  if (fair_input("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
     PADDLE_ENFORCE(h_dims == c_dims,
                    "The dimension of Input(H0) and Input(C0) "
@@ -79,7 +111,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  if (ctx->HasInput("AttentionBias")) {
+  if (fair_input("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
                       "Input(AttentionBias)'s rank must be 2.");
@@ -89,7 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                       "AttentionBias shapes must be 1 * 1.");
   }
 
-  if (ctx->HasInput("AttentionScalar")) {
+  if (fair_input("AttentionScalar")) {
     auto dims = ctx->GetInputDim("AttentionScalar");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalar)'s rank must be 2.");
@@ -97,10 +129,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
     PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
   }
 
-  if (ctx->HasInput("AttentionScalarBias")) {
+  if (fair_input("AttentionScalarBias")) {
     auto dims = ctx->GetInputDim("AttentionScalarBias");
     PADDLE_ENFORCE(
-        ctx->HasInput("AttentionScalar"),
+        fair_input("AttentionScalar"),
         "AttentionScalar should not be null when have AttentionScalarBias.");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalarBias)'s rank must be 2.");
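
Note: below is a minimal standalone sketch of the check the fair_input/fair_output lambdas in this diff implement: a slot passes only when it maps to exactly one variable name, the name is not framework::kEmptyVarName, and that variable can be found in the scope. The Fair helper and the std::map/std::set stand-ins for the framework's VariableNameMap and Scope types are illustrative assumptions, not framework API.

// Sketch only. Fair() mirrors the fair_input/fair_output logic above,
// using simplified stand-ins for framework::VariableNameMap and Scope.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Assumed placeholder value; framework::kEmptyVarName is "@EMPTY@" in Paddle.
const char kEmptyVarName[] = "@EMPTY@";

bool Fair(const std::map<std::string, std::vector<std::string>>& slots,
          const std::set<std::string>& scope, const std::string& name) {
  auto it = slots.find(name);
  if (it == slots.end()) return false;  // slot not declared on the op
  const auto& vars = it->second;
  if (vars.size() != 1 || vars[0] == kEmptyVarName) {
    return false;  // must be exactly one, non-empty, variable name
  }
  return scope.count(vars[0]) != 0;  // the variable must exist in the scope
}

int main() {
  std::map<std::string, std::vector<std::string>> ins{
      {"X", {"x0"}}, {"H0", {kEmptyVarName}}};
  std::set<std::string> scope{"x0"};
  std::cout << Fair(ins, scope, "X") << "\n";   // 1: one var, present in scope
  std::cout << Fair(ins, scope, "H0") << "\n";  // 0: empty placeholder name
  std::cout << Fair(ins, scope, "C0") << "\n";  // 0: slot not declared
}

The optional inputs (H0, AttentionBias, AttentionScalar, AttentionScalarBias) switch to the same check in this diff, so their shape validation only runs when the variable is actually present in the runtime scope.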