parent
102d93712e
commit
7923d7271f
@@ -0,0 +1,128 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_seqpool_concat_op.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/operators/jit/kernels.h"
|
||||
|
||||
namespace paddle {
namespace operators {

void FusionSeqPoolConcatOp::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
                    "Inputs(X) of FusionSeqPoolConcatOp should not be empty.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Output(Out) of FusionSeqPoolConcatOp should not be null.");
  int axis = ctx->Attrs().Get<int>("axis");
  PADDLE_ENFORCE_EQ(axis, 1,
                    "FusionSeqPoolConcatOp only supports concat axis=1 yet.");
  PADDLE_ENFORCE_EQ(ctx->Attrs().Get<std::string>("pooltype"), "SUM",
                    "FusionSeqPoolConcatOp only supports sum pool type yet.");

  auto ins_dims = ctx->GetInputsDim("X");
  const size_t n = ins_dims.size();
  PADDLE_ENFORCE_GT(n, 0UL, "The number of inputs should be greater than 0.");
  if (n == 1) {
    LOG(WARNING) << "Only have one input, may waste memory";
  }

  // The output height should be confirmed in Compute,
  // since input lod is not accessible here.
  PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2UL,
                    "The dims size of the first input should be 2.");
  ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast<int>(n)});
}

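// The kernel data type is taken from the first input in "X"; all inputs are
// expected to share one data type (float or double for the CPU kernels
// registered at the bottom of this file).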
framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace());
}

void FusionSeqPoolConcatOpMaker::Make() {
  AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
  AddOutput("Out", "(LoDTensor) Output tensor of concat operator.");
  AddAttr<std::string>("pooltype",
                       "(string, default 'SUM') the pooling type applied to "
                       "each input, as in SequencePoolOp.")
      .SetDefault("SUM")
      .InEnum({"AVERAGE", "SUM", "SQRT"});
  AddAttr<int>("axis",
               "The axis along which the input tensors will be concatenated.")
      .SetDefault(1);
  AddComment(R"DOC(
Fusion of Sequence Pool (pooltype: sum, average or sqrt) and Concat Operator.
)DOC");
}

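// Illustrative example (not part of the original source): given three
// LoDTensor inputs that share batch size 2 and width 4, each input is
// sum-pooled per sequence into a [2, 4] tensor, and the three results are
// concatenated along axis 1, so Out has shape [2, 12] with the trivial LoD
// {0, 1, 2}.
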
template <typename T>
class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<LoDTensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");
    auto x0_lod = ins[0]->lod();
    auto x0_dims = ins[0]->dims();
    auto y_dims = out->dims();
    size_t bs = x0_lod[0].size() - 1;
    out->Resize({static_cast<int64_t>(bs), y_dims[1]});
    framework::LoD y_lod(1);
    y_lod[0].resize(bs + 1);
    for (size_t i = 0; i <= bs; ++i) {
      y_lod[0][i] = i;
    }
    out->set_lod(y_lod);
    auto place = ctx.GetPlace();
    T* y_data = out->mutable_data<T>(place);

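    // Each output row lays the pooled vectors of all inputs side by side:
    // input i contributes columns [i * w, (i + 1) * w) of every row, which
    // realizes the concatenation along axis 1.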
    int w = ins[0]->numel() / x0_dims[0];
    PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
                      "The output dims[1] should be divisible by w.");
    jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
    auto seqpool =
        jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
            attr);
    size_t n = ins.size();
    for (size_t i = 0; i < n; ++i) {
      auto x_dims = ins[i]->dims();
      auto x_lod = ins[i]->lod()[0];
      const T* src = ins[i]->data<T>();
      T* dst = y_data + i * w;
      PADDLE_ENFORCE_EQ(static_cast<int>(ins[i]->numel() / x_dims[0]), w,
                        "Width of all inputs should be equal.");
      PADDLE_ENFORCE_EQ(x_lod.size(), bs + 1,
                        "Batch size of all inputs should be equal.");
      for (size_t j = 0; j < bs; ++j) {
        attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
        seqpool(src, dst, &attr);
        dst += n * w;
        src += attr.h * attr.w;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqpool_concat, ops::FusionSeqPoolConcatOp,
                  ops::FusionSeqPoolConcatOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OP_CPU_KERNEL(fusion_seqpool_concat,
                       ops::FusionSeqPoolConcatKernel<float>,
                       ops::FusionSeqPoolConcatKernel<double>);
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

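// Fuses per-input sequence pooling (sum only, for now) with concatenation of
// the pooled results along axis 1 into a single CPU operator.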
class FusionSeqPoolConcatOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class FusionSeqPoolConcatOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

}  // namespace operators
}  // namespace paddle