|
|
|
@ -23,7 +23,7 @@ namespace operators {
|
|
|
|
|
void FusionSeqPoolConcatOp::InferShape(
|
|
|
|
|
framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
|
|
|
|
|
"Inputs(X) of FusionSeqPoolConcatOp should be empty.");
|
|
|
|
|
"Inputs(X) of FusionSeqPoolConcatOp should not be empty.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FusionSeqPoolConcatOp should not be null.");
|
|
|
|
|
int axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
@ -54,12 +54,13 @@ void FusionSeqPoolConcatOpMaker::Make() {
|
|
|
|
|
AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
|
|
|
|
|
AddOutput("Out", "(LoDTensor) Output tensor of concat operator.");
|
|
|
|
|
AddAttr<std::string>("pooltype",
|
|
|
|
|
"(string, default 'AVERAGE') some of the pooling "
|
|
|
|
|
"(string, default 'SUM') some of the pooling "
|
|
|
|
|
"pooltype of SequencePoolOp.")
|
|
|
|
|
.SetDefault("SUM")
|
|
|
|
|
.InEnum({"AVERAGE", "SUM", "SQRT"});
|
|
|
|
|
AddAttr<int>("axis",
|
|
|
|
|
"The axis along which the input tensors will be concatenated.")
|
|
|
|
|
"The axis along which the input tensors will be concatenated. "
|
|
|
|
|
"Only supports concat axis=1 yet.")
|
|
|
|
|
.SetDefault(1);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator.
|
|
|
|
@ -100,6 +101,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
|
|
|
|
|
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
|
|
|
|
|
attr);
|
|
|
|
|
size_t n = ins.size();
|
|
|
|
|
size_t dst_step_size = n * w;
|
|
|
|
|
for (size_t i = 0; i < n; ++i) {
|
|
|
|
|
auto x_dims = ins[i]->dims();
|
|
|
|
|
auto x_lod = ins[i]->lod()[0];
|
|
|
|
@ -112,7 +114,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (size_t j = 0; j < bs; ++j) {
|
|
|
|
|
attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
|
|
|
|
|
seqpool(src, dst, &attr);
|
|
|
|
|
dst += n * w;
|
|
|
|
|
dst += dst_step_size;
|
|
|
|
|
src += attr.h * attr.w;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|