@@ -23,91 +23,36 @@ namespace paddle {
namespace operators {

void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"),
                 "Input(X) of FusionSeqConcatFC should not be null.");
  PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
                    "The number of Inputs(X) of FusionSeqConcatFCOp should be larger than 1.");
  PADDLE_ENFORCE(ctx->HasInput("FCWeight"),
                 "Input(FCWeight) of FusionSeqConcatFC should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Output(Out) of FusionSeqConcatFC should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("FCOut"),
                 "Output(FCOut) of FusionSeqConcatFC should not be null.");

  // need to check that the FC height equals the sum of all input widths
  auto x_dims = ctx->GetInputDim("X");
  const int M = x_dims[1];
  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");

  auto w_dims = ctx->GetInputDim("LSTMWeight");
  const int D = w_dims[1] / 4;
  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);

  auto b_dims = ctx->GetInputDim("LSTMBias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);

  auto c_dims = ctx->GetInputDim("C0");
  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
  if (ctx->HasInput("H0")) {
    auto h_dims = ctx->GetInputDim("H0");
    PADDLE_ENFORCE(h_dims == c_dims,
                   "The dimension of Input(H0) and Input(C0) "
                   "should be the same.");
  }

  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
                    "Input(AttentionWeight)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
  if (ctx->HasInput("AttentionBias")) {
    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
                      "Input(AttentionBias)'s rank must be 2.");
    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
                      "AttentionBias shapes must be 1 * 1.");
    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
                      "AttentionBias shapes must be 1 * 1.");
  auto ins_dims = ctx->GetInputsDim("X");
  auto w_dims = ctx->GetInputDim("FCWeight");  // (M0+M1+M2+..) x D
  PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, "Input(FCWeight)'s rank must be 2.");
  const int D = w_dims[1];
  int sum = ins_dims[0][1];
  for (size_t i = 1; i < ins_dims.size(); ++i) {
    sum += ins_dims[i][1];
  }
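  // Shape contract enforced here: every input X[i] is 2-D and contributes its
  // width ins_dims[i][1] to the concatenated feature, so FCWeight must have
  // (M0 + M1 + ...) rows. For illustration only: three inputs of widths 16, 8
  // and 8 would require a 32 x D FCWeight.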
  if (ctx->HasInput("AttentionScalar")) {
    auto dims = ctx->GetInputDim("AttentionScalar");
    PADDLE_ENFORCE_EQ(dims.size(), 2,
                      "Input(AttentionScalar)'s rank must be 2.");
    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
  PADDLE_ENFORCE_EQ(sum, w_dims[0],
                    "FC height should be the sum of all input widths.");
  if (ctx->HasInput("FCBias")) {
    auto b_dims = ctx->GetInputDim("FCBias");
    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2.");
    PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D);
    PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D);
  }
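  // FCBias, when present, is a single 1 x D row vector; the FC computation in
  // the kernel is expected to broadcast-add it to every output row.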
  if (ctx->HasInput("AttentionScalarBias")) {
    auto dims = ctx->GetInputDim("AttentionScalarBias");
    PADDLE_ENFORCE(
        ctx->HasInput("AttentionScalar"),
        "AttentionScalar should not be null when AttentionScalarBias is set.");
    PADDLE_ENFORCE_EQ(dims.size(), 2,
                      "Input(AttentionScalarBias)'s rank must be 2.");
    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
  }
  framework::DDim out_dims({x_dims[0], D});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
  ctx->SetOutputDim("LSTMX", {1, M});
  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
  // AttentionFCOut should be reshaped to (maxseqlen, 1) at runtime
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");

  ctx->SetOutputDim("Out", out_dims);
  ctx->ShareLoD("X", /*->*/ "Out");
  ctx->SetOutputDim("Out", {ins_dims[0][0], D});
  // FCOut should be reshaped at runtime, since the LoD cannot be obtained in
  // InferShape.
  // Explicitly share the LoD of the reference (first) input.
  ctx->ShareLoD("X", "Out", 0);
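  // Out is declared as (ins_dims[0][0] x D), one row per time step of the
  // first (reference) input, and reuses that input's LoD. FCOut only gets its
  // real shape (N x D, one row per sequence) in the kernel, once the batch
  // size N is known from the LoD.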
}

framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
@@ -154,46 +99,46 @@ The concat axis should be 1.
)DOC");
}
// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
  if (bias) {
    math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
    math::vec_relu<T, platform::jit::avx>(n, y, y);
  } else {
    math::vec_relu<T, platform::jit::avx>(n, x, y);
  }
}

template <typename T>
inline void vec_softmax(const int n, const T* x, T* y) {
  T scalar = x[0];
  // max
  for (int i = 1; i < n; ++i) {
    scalar = scalar < x[i] ? x[i] : scalar;
  }
  math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y);  // sub
  math::vec_exp<T>(n, y, y);                                    // exp
  // sum
  scalar = T(0);
  for (int i = 0; i < n; ++i) {
    scalar += y[i];
  }
  math::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
}
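// Note on vec_softmax above: subtracting the running max before vec_exp is the
// standard guard against overflow in exp(); the result is unchanged since the
// shift cancels out after normalization.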
template <typename T>
class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    auto* ins = ctx.Input<LoDTensor>("X");
    auto ins = ctx.MultiInput<LoDTensor>("X");
    auto* w = ctx.Input<Tensor>("FCWeight");
    auto* b = ctx.Input<Tensor>("FCBias");

    auto* out = ctx.Output<LoDTensor>("Out");
    auto* fc_out = ctx.Output<Tensor>("FCOut");
    auto* ref_in = ins[0];
    auto ref_lod = ref_in->lod();
    auto in1_lod = ins[1]->lod();
    auto ref_dims = ref_in->dims();  // T x M0
    auto in1_dims = ins[1]->dims();  // N x M1
    auto w_dims = w->dims();
    const int N = ref_lod[0].size() - 1;
    const int total_T = ref_dims[0];
    const int M0 = ref_dims[1];
    const int M1 = in1_dims[1];
    const int D = w_dims[1];

    // Some checks, and the FCOut reshape, have to happen here,
    // since InferShape cannot get the LoD information.
    PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only an input LoD size of 1 is supported.");
    PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only an input LoD size of 1 is supported.");
    PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N,
                      "Batch size of all inputs should be equal.");
    PADDLE_ENFORCE_EQ(in1_lod[0][N], N,
                      "Sequence length of other inputs should be 1.");
    PADDLE_ENFORCE_EQ(in1_dims[0], N, "Input height should be batch size.");
    for (size_t i = 2; i < ins.size(); ++i) {
      PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
                        "All other inputs' height should be equal.");
      PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
                        "All other inputs should have the same LoD.");
    }
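    // Runtime contract verified above: ins[0] is the time-expanded sequence
    // (total_T x M0) carrying the LoD, while every other input holds exactly
    // one row per sequence (N x Mi), i.e. a per-sequence feature that will be
    // concatenated onto each time step.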
    fc_out->Resize({N, D});

    std::function<void(const int, const T*, T*)> fc_act;
    auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
    if (platform::jit::MayIUse(platform::jit::avx)) {
@@ -204,19 +149,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
      fc_act = act_functor(fc_act_str);
    }
    PADDLE_ENFORCE_GT(ins.size(), 1, "Input(X)'s size must be larger than 1.");
    auto* ref_in = ins[0];
    auto ref_in_lod = ref_in->lod();
    const int N = ref_in_lod[0].size() - 1;
    auto ref_in_dims = ref_in->dims();  // T x M0
    auto w_dims = w->dims();            // (M0+M1+M2+..) x D
    const int total_T = ref_in_dims[0];
    const int M0 = ref_in_dims[1];
    const int M1 = ins[1]->dims()[1];
    const int D = w_dims[1];

    const T* ref_in_data =
        ref_in->data<T>();  // size should be checked at InferShape
    const T* ref_in_data = ref_in->data<T>();
    const T* in1_data = ins[1]->data<T>();
    const T* w_data = w->data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
@@ -226,11 +159,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
    math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
                                      out_data, b ? b->data<T>() : NULL);
    w_data = w_data + M0 * D;
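    // FCWeight stacks one block of rows per input: the first M0 rows multiply
    // ins[0], the next M1 rows multiply ins[1], and so on. Advancing w_data by
    // M0 * D elements (row-major storage) moves it to the block for the second
    // input.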
    // The first of the remaining inputs writes fc_out directly.
    blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
    w_data = w_data + M1 * D;
    for (int i = 2; i < ins.size(); ++i) {
    for (size_t i = 2; i < ins.size(); ++i) {
      // The rest accumulate on top of it.
      const T* in_data = ins[i]->data<T>();
      const int K = ins[i]->dims()[1];
@@ -240,7 +172,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
    }
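    // fc_out now holds one D-wide row per sequence. The loop below broadcasts
    // that row over every time step of its sequence by adding it onto the
    // corresponding rows of out, which already contain ins[0] * W0 plus bias.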
    for (int i = 0; i < N; ++i) {
      int seq_len = ref_in_lod[0][i + 1] - ref_in_lod[0][i];
      int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
      T* src = fc_out_data + i * D;
      for (int step = 0; step < seq_len; ++step) {
        blas.VADD(D, out_data, src, out_data);
@@ -248,7 +180,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
      }
    }

    fc_act(out_dims[0] * out_dims[1], out_data, out_data);
    fc_act(total_T * D, out_data, out_data);
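    // total_T * D is the element count of Out (T x D), so the configured
    // fc_activation is applied element-wise over the whole fused result.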
  }
};