rename fusion seq_concat_fc to fusion seqexpand_concat_fc

createGenDocLib
tensor-tang 7 years ago
parent 0f0d48230c
commit 02909335e9

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fusion_seq_concat_fc_op.h" #include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
@ -22,15 +22,20 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { void FusionSeqExpandConcatFCOp::InferShape(
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL, framework::InferShapeContext* ctx) const {
"Inputs(X) of FusionSeqConcatFCOp should larger than 1."); PADDLE_ENFORCE_GT(
PADDLE_ENFORCE(ctx->HasInput("FCWeight"), ctx->Inputs("X").size(), 1UL,
"Input(FCWeight) of FusionSeqConcatFC should not be null."); "Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(
"Output(Out) of FusionSeqConcatFC should not be null."); ctx->HasInput("FCWeight"),
PADDLE_ENFORCE(ctx->HasOutput("FCOut"), "Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
"Output(FCOut) of FusionSeqConcatFC should not be null."); PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
auto ins_dims = ctx->GetInputsDim("X"); auto ins_dims = ctx->GetInputsDim("X");
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
@ -55,14 +60,14 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Out", 0); ctx->ShareLoD("X", "Out", 0);
} }
framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
void FusionSeqConcatFCOpMaker::Make() { void FusionSeqExpandConcatFCOpMaker::Make() {
AddInput("X", AddInput("X",
"(LoDTensor) input LodDTensors, the first one must be have ref lod " "(LoDTensor) input LodDTensors, the first one must be have ref lod "
"for sequence expand, and the rest input should have same lod.") "for sequence expand, and the rest input should have same lod.")
@ -100,7 +105,7 @@ The concat axis should be 1.
} }
template <typename T> template <typename T>
class FusionSeqConcatFCKernel : public framework::OpKernel<T> { class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
@ -188,10 +193,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seq_concat_fc, ops::FusionSeqConcatFCOp, REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp,
ops::FusionSeqConcatFCOpMaker, ops::FusionSeqExpandConcatFCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seq_concat_fc, REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc,
ops::FusionSeqConcatFCKernel<float>, ops::FusionSeqExpandConcatFCOpKernel<float>,
ops::FusionSeqConcatFCKernel<double>); ops::FusionSeqExpandConcatFCOpKernel<double>);

@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ namespace operators {
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
class FusionSeqConcatFCOp : public framework::OperatorWithKernel { class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -32,7 +32,8 @@ class FusionSeqConcatFCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
}; };
class FusionSeqConcatFCOpMaker : public framework::OpProtoAndCheckerMaker { class FusionSeqExpandConcatFCOpMaker
: public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override; void Make() override;
}; };

@ -51,7 +51,7 @@ class TestFusionSeqExpandConcatFCOp(OpTest):
pass pass
def setUp(self): def setUp(self):
self.op_type = 'fusion_seq_concat_fc' self.op_type = 'fusion_seqexpand_concat_fc'
self.lod = [[3, 5, 8, 2]] self.lod = [[3, 5, 8, 2]]
self.inputs_M = [15, 10, 10] self.inputs_M = [15, 10, 10]
self.D = 20 self.D = 20
Loading…
Cancel
Save