@@ -23,16 +23,23 @@ namespace operators {
 void FusionSeqPoolConcatOp::InferShape(
     framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
-                    "Inputs(X) of FusionSeqPoolConcatOp should not be empty.");
-  PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                 "Output(Out) of FusionSeqPoolConcatOp should not be null.");
+                    platform::errors::InvalidArgument(
+                        "Inputs(X) of FusionSeqPoolConcatOp should be greated "
+                        "than 1, but received value is %d.",
+                        ctx->Inputs("X").size()));
+  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FusionSeqPoolConcat");
   int axis = ctx->Attrs().Get<int>("axis");
-  PADDLE_ENFORCE_EQ(axis, 1,
-                    "FusionSeqPoolConcatOp only supports concat axis=1 yet.");
+  PADDLE_ENFORCE_EQ(axis, 1, platform::errors::InvalidArgument(
+                                 "FusionSeqPoolConcatOp only supports concat "
+                                 "axis=1 yet, but received axis value is %d",
+                                 axis));
 
   auto ins_dims = ctx->GetInputsDim("X");
   const size_t n = ins_dims.size();
-  PADDLE_ENFORCE_GT(n, 0UL, "Input tensors count should > 0.");
+  PADDLE_ENFORCE_GT(n, 0UL, platform::errors::InvalidArgument(
+                                "Input tensors count should be greater than 0, "
+                                "but received value is %d.",
+                                n));
   if (n == 1) {
     LOG(WARNING) << "Only have one input, may waste memory";
   }
@@ -40,7 +47,10 @@ void FusionSeqPoolConcatOp::InferShape(
   // The output height should be confirmed in Compute,
   // since input lod is not accessible here.
   PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2,
-                    "The dims size of first input should be 2.");
+                    platform::errors::InvalidArgument(
+                        "The dims size of first input should be equal to 2, "
+                        "but received value is %d.",
+                        ins_dims[0].size()));
   ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast<int>(n)});
 
   if (!ctx->IsRuntime()) {
@@ -96,7 +106,10 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
 
     int w = ins[0]->numel() / x0_dims[0];
     PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
-                      "The output of dims[1] should be dividable of w");
+                      platform::errors::InvalidArgument(
+                          "The output of dims[1] should be dividable of w, but "
+                          "dims[1] is %d, w is %d.",
+                          y_dims[1], w));
     jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
     if (pooltype == "AVERAGE") {
       attr.type = jit::SeqPoolType::kAvg;
@@ -113,10 +126,18 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
       auto x_lod = ins[i]->lod()[0];
       const T* src = ins[i]->data<T>();
       T* dst = y_data + i * w;
-      PADDLE_ENFORCE_EQ(static_cast<int>(ins[i]->numel() / x_dims[0]), w,
-                        "Width of all inputs should be equal.");
-      PADDLE_ENFORCE_EQ(x_lod.size(), bs + 1,
-                        "Batchsize of all inputs should be equal.");
+      PADDLE_ENFORCE_EQ(
+          static_cast<int>(ins[i]->numel() / x_dims[0]), w,
+          platform::errors::InvalidArgument(
+              "Width of all inputs should be equal, but the width of the %d-th "
+              "input %d is not equal to the previous %d",
+              i, static_cast<int>(ins[i]->numel() / x_dims[0]), w));
+      PADDLE_ENFORCE_EQ(
+          x_lod.size(), bs + 1,
+          platform::errors::InvalidArgument(
+              "Batchsize of all inputs should be equal, but the value of the "
+              "%d-th %d is not equal to the previous %d.",
+              i, x_lod.size(), bs + 1));
       for (size_t j = 0; j < bs; ++j) {
         attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
         seqpool(src, dst, &attr);
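
The pattern applied in every hunk above is the same: a PADDLE_ENFORCE_* call whose third argument was a bare message string now wraps that message in platform::errors::InvalidArgument, which attaches an error category and a printf-style report of the offending value. A minimal standalone sketch of that pattern follows; the helper name CheckConcatAxis is illustrative and not part of this diff, and the include path assumes the fluid-era PaddlePaddle source layout.

    // Sketch only: assumes a PaddlePaddle build environment where
    // paddle/fluid/platform/enforce.h provides PADDLE_ENFORCE_EQ and the
    // platform::errors factories used throughout this diff.
    #include "paddle/fluid/platform/enforce.h"

    // Hypothetical helper mirroring the axis check rewritten in hunk 1.
    void CheckConcatAxis(int axis) {
      // Old style: a bare string, no error category, no offending value:
      //   PADDLE_ENFORCE_EQ(axis, 1, "only supports concat axis=1 yet.");
      // New style: a typed InvalidArgument error whose message includes the
      // value that failed the check.
      PADDLE_ENFORCE_EQ(axis, 1,
                        paddle::platform::errors::InvalidArgument(
                            "Only concat axis=1 is supported, but received %d.",
                            axis));
    }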