You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							207 lines
						
					
					
						
							7.8 KiB
						
					
					
				
			
		
		
	
	
							207 lines
						
					
					
						
							7.8 KiB
						
					
					
				/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
						|
 | 
						|
Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
you may not use this file except in compliance with the License.
 | 
						|
You may obtain a copy of the License at
 | 
						|
 | 
						|
http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
Unless required by applicable law or agreed to in writing, software
 | 
						|
distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
See the License for the specific language governing permissions and
 | 
						|
limitations under the License. */
 | 
						|
 | 
						|
#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h"
 | 
						|
#include <string>
 | 
						|
#include "paddle/fluid/operators/math/blas.h"
 | 
						|
#include "paddle/fluid/operators/math/cpu_vec.h"
 | 
						|
#include "paddle/fluid/operators/math/fc_compute.h"
 | 
						|
#include "paddle/fluid/platform/cpu_info.h"
 | 
						|
 | 
						|
namespace paddle {
 | 
						|
namespace operators {
 | 
						|
 | 
						|
// Checks that the required inputs/outputs are present, validates the FC
// weight and optional bias shapes against the inputs, and sets the static
// shape of Out.
//
// FCWeight is expected to be (M0 + M1 + M2 + ...) x D, where Mi is the width
// (dims[1]) of the i-th entry of X. Out is given the height of X[0] and width
// D. FCOut is not shaped here: its height depends on the lod, which is not
// available during shape inference.
void FusionSeqExpandConcatFCOp::InferShape(
    framework::InferShapeContext* ctx) const {
  // Presence checks.
  PADDLE_ENFORCE_GT(
      ctx->Inputs("X").size(), 1UL,
      "Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
  PADDLE_ENFORCE(
      ctx->HasInput("FCWeight"),
      "Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
  PADDLE_ENFORCE(
      ctx->HasOutput("Out"),
      "Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
  PADDLE_ENFORCE(
      ctx->HasOutput("FCOut"),
      "Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");

  auto x_dims = ctx->GetInputsDim("X");
  auto weight_dims = ctx->GetInputDim("FCWeight");  // (M0+M1+M2+..) x D
  PADDLE_ENFORCE_EQ(weight_dims.size(), 2UL,
                    "Input(FCWeight)'s rank must be 2.");
  const int out_width = weight_dims[1];

  // The weight height must match the width of the (virtual) concatenation
  // of all inputs along axis 1.
  int total_in_width = 0;
  for (const auto& dims : x_dims) {
    total_in_width += dims[1];
  }
  PADDLE_ENFORCE_EQ(total_in_width, weight_dims[0],
                    "FC height should be sum of all inputs width.");

  // The bias is optional; when given it must be either D or 1 x D.
  if (ctx->HasInput("FCBias")) {
    auto bias_dims = ctx->GetInputDim("FCBias");
    PADDLE_ENFORCE(bias_dims.size() == 1 || bias_dims.size() == 2,
                   "b_dims should be 1 or 2, get %d", bias_dims.size());
    if (bias_dims.size() == 2) {
      PADDLE_ENFORCE_EQ(bias_dims[0], 1, "FCBias shapes must be 1x%d.",
                        out_width);
      PADDLE_ENFORCE_EQ(bias_dims[1], out_width, "FCBias shapes must be 1x%d.",
                        out_width);
    } else {
      PADDLE_ENFORCE_EQ(bias_dims[0], out_width, "FCBias shapes must be %d.",
                        out_width);
    }
  }

  ctx->SetOutputDim("Out", {x_dims[0][0], out_width});
  // FCOut is reshaped at run time because the lod is unavailable here.
  // Explicitly share the lod of the reference (first) input with Out.
  ctx->ShareLoD("X", "Out", 0);
}
 | 
						|
 | 
						|
// Selects the kernel from the element type of the first X input and the
// device context the operator is executed with.
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  auto inputs = ctx.MultiInput<LoDTensor>("X");
  auto data_type = framework::ToDataType(inputs[0]->type());
  return framework::OpKernelType(data_type, ctx.device_context());
}
 | 
						|
 | 
						|
void FusionSeqExpandConcatFCOpMaker::Make() {
 | 
						|
  AddInput("X",
 | 
						|
           "(LoDTensor) input LodDTensors, the first one must be have ref lod "
 | 
						|
           "for sequence expand, and the rest input should have same lod.")
 | 
						|
      .AsDuplicable();
 | 
						|
  AddInput("FCWeight", "(Tensor) the weights of fc.");
 | 
						|
  AddInput("FCBias", "(Tensor, optional) the bias of fc.").AsDispensable();
 | 
						|
  AddOutput("Out", "(LoDTensor) Output LodTensor.");
 | 
						|
  AddOutput(
 | 
						|
      "FCOut",
 | 
						|
      "(Tensor) the intermediate tensor to keep the result of fc."
 | 
						|
      "Shape is (N x D), where N is the batch size, D is the output dim of fc")
 | 
						|
      .AsIntermediate();
 | 
						|
  AddAttr<std::string>("fc_activation",
 | 
						|
                       "(string, default: identity)"
 | 
						|
                       "The activation for the result of fc."
 | 
						|
                       "`identity` by default.")
 | 
						|
      .SetDefault("identity")
 | 
						|
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
 | 
						|
  AddComment(R"DOC(
 | 
						|
Fusion Sequence expand + concat + fc Operator.
 | 
						|
 | 
						|
All below conditions should be meet:
 | 
						|
 | 
						|
The ref_level of seq_expand should be 0.
 | 
						|
 | 
						|
The ref lod of seq_expand level is the first input of concat.
 | 
						|
 | 
						|
The other inputs should have same lod and same batch size of ref lod.
 | 
						|
 | 
						|
The seq len of other inputs should be 1.
 | 
						|
 | 
						|
The concat axis should be 1.
 | 
						|
 | 
						|
)DOC");
 | 
						|
}
 | 
						|
 | 
						|
template <typename T>
 | 
						|
class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
 | 
						|
 public:
 | 
						|
  void Compute(const framework::ExecutionContext& ctx) const override {
 | 
						|
    using DeviceContext = paddle::platform::CPUDeviceContext;
 | 
						|
    auto ins = ctx.MultiInput<LoDTensor>("X");
 | 
						|
    auto* w = ctx.Input<Tensor>("FCWeight");
 | 
						|
    auto* b = ctx.Input<Tensor>("FCBias");
 | 
						|
    auto* out = ctx.Output<LoDTensor>("Out");
 | 
						|
    auto* fc_out = ctx.Output<Tensor>("FCOut");
 | 
						|
 | 
						|
    auto* ref_in = ins[0];
 | 
						|
    auto ref_lod = ref_in->lod();
 | 
						|
    auto in1_lod = ins[1]->lod();
 | 
						|
    auto ref_dims = ref_in->dims();  // T x M0
 | 
						|
    auto in1_dims = ins[1]->dims();  // N x M1
 | 
						|
    auto w_dims = w->dims();
 | 
						|
    const int N = ref_lod[0].size() - 1;
 | 
						|
    const int total_T = ref_dims[0];
 | 
						|
    const int M0 = ref_dims[1];
 | 
						|
    const int M1 = in1_dims[1];
 | 
						|
    const int D = w_dims[1];
 | 
						|
 | 
						|
    // some check and fcout should be reshape here
 | 
						|
    // since infershape can not get lod info
 | 
						|
    PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only support input lod size is 1.");
 | 
						|
    PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only support input lod size is 1.");
 | 
						|
    PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N,
 | 
						|
                      "Batch size of all inputs should be equal.");
 | 
						|
    PADDLE_ENFORCE_EQ(in1_lod[0][N], N,
 | 
						|
                      "Seq_length of other inputs should be 1.");
 | 
						|
    PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size.");
 | 
						|
    for (size_t i = 2; i < ins.size(); ++i) {
 | 
						|
      PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
 | 
						|
                        "All other inputs height should be equal");
 | 
						|
      PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
 | 
						|
                        "All other inputs should have same lod");
 | 
						|
    }
 | 
						|
    fc_out->Resize({N, D});
 | 
						|
 | 
						|
    std::function<void(const int, const T*, T*)> fc_act;
 | 
						|
    auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
 | 
						|
    if (platform::jit::MayIUse(platform::jit::avx)) {
 | 
						|
      math::VecActivations<T, platform::jit::avx> act_functor;
 | 
						|
      fc_act = act_functor(fc_act_str);
 | 
						|
    } else {
 | 
						|
      math::VecActivations<T, platform::jit::isa_any> act_functor;
 | 
						|
      fc_act = act_functor(fc_act_str);
 | 
						|
    }
 | 
						|
 | 
						|
    const T* ref_in_data = ref_in->data<T>();
 | 
						|
    const T* in1_data = ins[1]->data<T>();
 | 
						|
    const T* w_data = w->data<T>();
 | 
						|
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
 | 
						|
    T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
 | 
						|
 | 
						|
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
 | 
						|
    math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
 | 
						|
                                      out_data, b ? b->data<T>() : NULL);
 | 
						|
    w_data = w_data + M0 * D;
 | 
						|
    // first write on
 | 
						|
    blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
 | 
						|
    w_data = w_data + M1 * D;
 | 
						|
    for (size_t i = 2; i < ins.size(); ++i) {
 | 
						|
      // add on
 | 
						|
      const T* in_data = ins[i]->data<T>();
 | 
						|
      const int K = ins[i]->dims()[1];
 | 
						|
      blas.GEMM(CblasNoTrans, CblasNoTrans, N, D, K, static_cast<T>(1), in_data,
 | 
						|
                K, w_data, D, static_cast<T>(1), fc_out_data, D);
 | 
						|
      w_data = w_data + K * D;
 | 
						|
    }
 | 
						|
    T* cur_out_data = out_data;
 | 
						|
    for (int i = 0; i < N; ++i) {
 | 
						|
      int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
 | 
						|
      T* src = fc_out_data + i * D;
 | 
						|
      for (int step = 0; step < seq_len; ++step) {
 | 
						|
        blas.VADD(D, cur_out_data, src, cur_out_data);
 | 
						|
        cur_out_data = cur_out_data + D;
 | 
						|
      }
 | 
						|
    }
 | 
						|
    fc_act(total_T * D, out_data, out_data);
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
}  // namespace operators
 | 
						|
}  // namespace paddle
 | 
						|
 | 
						|
namespace ops = paddle::operators;
 | 
						|
REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp,
 | 
						|
                  ops::FusionSeqExpandConcatFCOpMaker,
 | 
						|
                  paddle::framework::DefaultGradOpDescMaker<true>);
 | 
						|
 | 
						|
REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc,
 | 
						|
                       ops::FusionSeqExpandConcatFCOpKernel<float>,
 | 
						|
                       ops::FusionSeqExpandConcatFCOpKernel<double>);
 |