rename fusion seq_concat_fc to fusion seqexpand_concat_fc

createGenDocLib
tensor-tang 7 years ago
parent 0f0d48230c
commit 02909335e9

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_seq_concat_fc_op.h"
#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
@ -22,15 +22,20 @@ limitations under the License. */
namespace paddle {
namespace operators {
void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqConcatFCOp should larger than 1.");
PADDLE_ENFORCE(ctx->HasInput("FCWeight"),
"Input(FCWeight) of FusionSeqConcatFC should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSeqConcatFC should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqConcatFC should not be null.");
void FusionSeqExpandConcatFCOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GT(
ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
PADDLE_ENFORCE(
ctx->HasInput("FCWeight"),
"Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
auto ins_dims = ctx->GetInputsDim("X");
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
@ -55,14 +60,14 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Out", 0);
}
framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionSeqConcatFCOpMaker::Make() {
void FusionSeqExpandConcatFCOpMaker::Make() {
AddInput("X",
"(LoDTensor) input LodDTensors, the first one must be have ref lod "
"for sequence expand, and the rest input should have same lod.")
@ -100,7 +105,7 @@ The concat axis should be 1.
}
template <typename T>
class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
@ -188,10 +193,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seq_concat_fc, ops::FusionSeqConcatFCOp,
ops::FusionSeqConcatFCOpMaker,
REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp,
ops::FusionSeqExpandConcatFCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seq_concat_fc,
ops::FusionSeqConcatFCKernel<float>,
ops::FusionSeqConcatFCKernel<double>);
REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc,
ops::FusionSeqExpandConcatFCOpKernel<float>,
ops::FusionSeqExpandConcatFCOpKernel<double>);

@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqConcatFCOp : public framework::OperatorWithKernel {
class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@ -32,7 +32,8 @@ class FusionSeqConcatFCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqConcatFCOpMaker : public framework::OpProtoAndCheckerMaker {
class FusionSeqExpandConcatFCOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};

@ -51,7 +51,7 @@ class TestFusionSeqExpandConcatFCOp(OpTest):
pass
def setUp(self):
self.op_type = 'fusion_seq_concat_fc'
self.op_type = 'fusion_seqexpand_concat_fc'
self.lod = [[3, 5, 8, 2]]
self.inputs_M = [15, 10, 10]
self.D = 20
Loading…
Cancel
Save