|
|
|
@ -24,38 +24,59 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
void FusionSeqExpandConcatFCOp::InferShape(
|
|
|
|
|
framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
ctx->Inputs("X").size(), 1UL,
|
|
|
|
|
"Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("FCWeight"),
|
|
|
|
|
"Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("FCOut"),
|
|
|
|
|
"Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Inputs(X) of FusionSeqExpandConcatFCOp should larger "
|
|
|
|
|
"than 1, but received value is: %d.",
|
|
|
|
|
ctx->Inputs("X").size()));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("FCWeight"), "Input", "FCWeight",
|
|
|
|
|
"fusion_seqexpand_concat_fc");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"fusion_seqexpand_concat_fc");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("FCOut"), "Output", "FCOut",
|
|
|
|
|
"fusion_seqexpand_concat_fc");
|
|
|
|
|
|
|
|
|
|
auto ins_dims = ctx->GetInputsDim("X");
|
|
|
|
|
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(FCWeight)'s rank must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
w_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(FCWeight)'s rank must be 2, but received value is: %d.",
|
|
|
|
|
w_dims.size()));
|
|
|
|
|
const int D = w_dims[1];
|
|
|
|
|
int sum = ins_dims[0][1];
|
|
|
|
|
for (size_t i = 1; i < ins_dims.size(); ++i) {
|
|
|
|
|
sum += ins_dims[i][1];
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(sum, w_dims[0],
|
|
|
|
|
"FC height should be sum of all inputs width.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(sum, w_dims[0], platform::errors::InvalidArgument(
|
|
|
|
|
"FC height should be sum of all inputs "
|
|
|
|
|
"width, but received FC height is: %d, "
|
|
|
|
|
"sum of all inputs width is: %d.",
|
|
|
|
|
w_dims[0], sum));
|
|
|
|
|
if (ctx->HasInput("FCBias")) {
|
|
|
|
|
auto b_dims = ctx->GetInputDim("FCBias");
|
|
|
|
|
PADDLE_ENFORCE(b_dims.size() == 1 || b_dims.size() == 2,
|
|
|
|
|
"b_dims should be 1 or 2, get %d", b_dims.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
b_dims.size() == 1 || b_dims.size() == 2, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"FCBias dim should be 1 or 2, but received value is: %d.",
|
|
|
|
|
b_dims.size()));
|
|
|
|
|
if (b_dims.size() == 1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], D, "FCBias shapes must be %d.", D);
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], D,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"FCBias shapes must be %d when FCBias dim = 1, but "
|
|
|
|
|
"received value is: %d.",
|
|
|
|
|
D, b_dims[0]));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1x%d.", D);
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1x%d.", D);
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"FCBias shapes must be 1x%d, when FCBias dim = 2, "
|
|
|
|
|
"but received dim[0] is: %d.",
|
|
|
|
|
D, b_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[1], D,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"FCBias shapes must be 1x%d, when FCBias dim = 2, "
|
|
|
|
|
"but received dim[1] is: %d.",
|
|
|
|
|
D, b_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -133,18 +154,42 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// some check and fcout should be reshape here
|
|
|
|
|
// since infershape can not get lod info
|
|
|
|
|
PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only support input lod size is 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only support input lod size is 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ref_lod.size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Only support input lod size is 1, but received value is: %d.",
|
|
|
|
|
ref_lod.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in1_lod.size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Only support input lod size is 1, but received value is: %d.",
|
|
|
|
|
in1_lod.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<int>(in1_lod[0].size() - 1), N,
|
|
|
|
|
"Batch size of all inputs should be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<int>(in1_lod[0][N]), N,
|
|
|
|
|
"Seq_length of other inputs should be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Batch size of all inputs should be equal to %d, but "
|
|
|
|
|
"received value is: %d.",
|
|
|
|
|
N, static_cast<int>(in1_lod[0].size() - 1)));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
static_cast<int>(in1_lod[0][N]), N,
|
|
|
|
|
platform::errors::InvalidArgument("Seq_length of other inputs should "
|
|
|
|
|
"be %d, but received value is: %d.",
|
|
|
|
|
N, static_cast<int>(in1_lod[0][N])));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in1_dims[0], N,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input height should be batch size: %d, but received value is %d.",
|
|
|
|
|
N, in1_dims[0]));
|
|
|
|
|
for (size_t i = 2; i < ins.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
|
|
|
|
|
"All other inputs height should be equal");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"All other inputs height should be equal to %d, "
|
|
|
|
|
"but received value is: %d.",
|
|
|
|
|
N, ins[i]->dims()[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
|
|
|
|
|
"All other inputs should have same lod");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"All other inputs should have same lod: %d, but "
|
|
|
|
|
"received value is: %d.",
|
|
|
|
|
in1_lod, ins[i]->lod()));
|
|
|
|
|
}
|
|
|
|
|
fc_out->Resize({N, D});
|
|
|
|
|
|
|
|
|
|