/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/fc_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;

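// MKLDNNMD derives the MKL-DNN memory descriptors (src, weights, bias, dst)
// for a single FC primitive from the shapes of the input and weight tensors.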
template <typename T>
class MKLDNNMD {
 public:
  explicit MKLDNNMD(const T* in, const T* w, bool bias)
      : in(paddle::framework::vectorize2int(in->dims())),
        w(paddle::framework::vectorize2int(w->dims())) {
    with_bias_ = bias;
  }

  mkldnn::memory::desc dst() const {
    return platform::MKLDNNMemDesc({in[0], w[1]},
                                   mkldnn::memory::data_type::f32,
                                   mkldnn::memory::format::nc);
  }

  mkldnn::memory::desc src() const {
    return is_spatial()
               ? platform::MKLDNNMemDesc({in[0], in[1], in[2], in[3]},
                                         mkldnn::memory::data_type::f32,
                                         mkldnn::memory::format::nchw)
               : platform::MKLDNNMemDesc({in[0], in[1]},
                                         mkldnn::memory::data_type::f32,
                                         mkldnn::memory::format::nc);
  }

  mkldnn::memory::desc weights() const {
    return is_spatial()
               ? platform::MKLDNNMemDesc({w[1], in[1], in[2], in[3]},
                                         mkldnn::memory::data_type::f32,
                                         mkldnn::memory::format::oihw)
               : platform::MKLDNNMemDesc({w[1], in[1]},
                                         mkldnn::memory::data_type::f32,
                                         mkldnn::memory::format::oi);
  }

  mkldnn::memory::desc bias() const {
    return with_bias_
               ? platform::MKLDNNMemDesc({w[1]}, mkldnn::memory::data_type::f32,
                                         mkldnn::memory::format::format_undef)
               : platform::MKLDNNMemDesc({}, mkldnn::memory::data_type::f32,
                                         mkldnn::memory::format::format_undef);
  }

 private:
  // Only 4-D (NCHW) tensors carry spatial H/W dimensions; 2-D (NC) tensors
  // must take the plain nc/oi descriptors, otherwise in[2]/in[3] would be
  // read out of range above.
  bool is_spatial() const { return in.size() > 2 && w.size() > 2; }

  std::vector<int> in;
  std::vector<int> w;
  bool with_bias_;
};

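// MKLDNNMemory wraps raw tensor buffers in mkldnn::memory objects, using the
// descriptors provided by MKLDNNMD, so the FC primitives can consume them.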
class MKLDNNMemory {
 public:
  MKLDNNMemory(MKLDNNMD<Tensor>* t, const mkldnn::engine& e)
      : md_(t), engine_(e) {}
  virtual ~MKLDNNMemory() = default;

  template <typename Output>
  mkldnn::memory dst(const Output* out) {
    // MKL-DNN takes non-const void* handles; the const_cast is safe because
    // read-only buffers are never written by the primitives.
    return mkldnn::memory({md_->dst(), engine_},
                          static_cast<void*>(const_cast<float*>(out)));
  }

  template <typename Output>
  mkldnn::memory dst(Output* out) {
    return mkldnn::memory({md_->dst(), engine_}, out);
  }

  template <typename Input>
  mkldnn::memory src(const Input* in) {
    return mkldnn::memory({md_->src(), engine_},
                          static_cast<void*>(const_cast<float*>(in)));
  }

  template <typename Weight>
  mkldnn::memory weights(const Weight* w) {
    return mkldnn::memory({md_->weights(), engine_},
                          static_cast<void*>(const_cast<float*>(w)));
  }

  mkldnn::memory bias() {
    // No user buffer is attached here; MKL-DNN allocates scratch storage.
    return mkldnn::memory(mkldnn::memory::primitive_desc(md_->bias(), engine_));
  }

 private:
  MKLDNNMD<Tensor>* md_;
  const mkldnn::engine& engine_;
};

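// Forward FC kernel: runs the computation as an MKL-DNN inner_product
// primitive and caches the forward primitive descriptor in the device
// context, keyed by the output variable name, for reuse by the grad kernel.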
template <typename T>
class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto input = ctx.Input<Tensor>("Input");
    auto w = ctx.Input<Tensor>("W");
    auto bias = ctx.Input<Tensor>("Bias");

    PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
                   "Input must have 2 or 4 dimensions, i.e. NC or NCHW");
    // TODO(intel friends): the native weight format is io,
    // but the mkldnn weight format is oihw, which may need to be transposed.
    PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
                   "Weights must have 2 or 4 dimensions, i.e. OI or OIHW");

    bool with_bias = bias != nullptr;
    MKLDNNMD<Tensor> md(input, w, with_bias);

    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd =
        FcFwdPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
                           with_bias, mkldnn_engine);

    // Cache the primitive descriptor so the grad kernel can look it up.
    const std::string key = ctx.op().Output("Out");
    const std::string key_fc_pd = key + "@fc_pd";
    dev_ctx.SetBlob(key_fc_pd, pd);

    MKLDNNMemory mem(&md, mkldnn_engine);

    const T* input_data = input->data<T>();
    const T* w_data = w->data<T>();

    auto output = ctx.Output<Tensor>("Out");
    T* output_data = output->mutable_data<T>(ctx.GetPlace());

    auto dst_memory = mem.dst(output_data);
    auto src_memory = mem.src(input_data);
    auto weights_memory = mem.weights(w_data);
    // TODO(intel friends): bias memory should also be obtained from
    // bias->data().
    auto bias_memory = mem.bias();

    auto forward = with_bias ? mkldnn::inner_product_forward(
                                   *pd, src_memory, weights_memory, bias_memory,
                                   dst_memory)
                             : mkldnn::inner_product_forward(
                                   *pd, src_memory, weights_memory, dst_memory);

    std::vector<mkldnn::primitive> pipeline = {forward};
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
  }

 private:
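  // Creates the forward inner-product primitive descriptor, including the
  // bias term only when with_bias is set.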
  std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>
  FcFwdPrimitiveDesc(const mkldnn::memory::desc& src,
                     const mkldnn::memory::desc& weights,
                     const mkldnn::memory::desc& dst,
                     const mkldnn::memory::desc& bias, const bool with_bias,
                     const mkldnn::engine& engine) const {
    auto desc = with_bias
                    ? mkldnn::inner_product_forward::desc(
                          mkldnn::prop_kind::forward, src, weights, bias, dst)
                    : mkldnn::inner_product_forward::desc(
                          mkldnn::prop_kind::forward, src, weights, dst);

    auto pd = new mkldnn::inner_product_forward::primitive_desc(desc, engine);
    return std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>(pd);
  }
};

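// Backward FC kernel: recovers the cached forward primitive descriptor and
// computes the weight and/or input gradients with the corresponding
// inner_product backward primitives.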
template <typename T>
class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    T* input_grad_data = nullptr;
    T* w_grad_data = nullptr;

    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("W"));

    if (input_grad) {
      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    }
    if (w_grad) {
      w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
    }

    const Tensor* input = ctx.Input<Tensor>("Input");
    const T* input_data = input->data<T>();

    const Tensor* w = ctx.Input<Tensor>("W");
    const T* w_data = w->data<T>();

    const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    const T* out_grad_data = out_grad->data<T>();

    auto bias = ctx.Input<Tensor>("Bias");
    bool with_bias = bias != nullptr;

    MKLDNNMD<Tensor> md(input, w, with_bias);
    MKLDNNMemory mem(&md, mkldnn_engine);

    auto dst_memory = mem.dst(out_grad_data);
    auto src_memory = mem.src(input_data);
    auto weights_memory = mem.weights(w_data);
    auto bias_memory = mem.bias();

    // Recover the forward primitive descriptor cached under the same key.
    const std::string key = ctx.op().Input("Out");
    const std::string key_fc_pd = key + "@fc_pd";

    auto pd =
        std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
            dev_ctx.GetBlob(key_fc_pd));

    PADDLE_ENFORCE(pd != nullptr, "Failed to find fc_pd in device context");

    if (w_grad) {
      auto weights_grad_memory = mem.weights(w_grad_data);

      mkldnn::inner_product_backward_weights::primitive_desc bwd_weight_pd =
          FcBwdWeightsPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
                                    with_bias, *pd, mkldnn_engine);

      auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
          bwd_weight_pd, src_memory, dst_memory, weights_grad_memory,
          bias_memory);

      std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
    }

    if (input_grad) {
      auto src_grad_memory = mem.src(input_grad_data);

      mkldnn::inner_product_backward_data::primitive_desc bwd_data_pd =
          FcBwdDataPrimitiveDesc(md.src(), md.weights(), md.dst(), *pd,
                                 mkldnn_engine);

      auto bwd_data_prim = mkldnn::inner_product_backward_data(
          bwd_data_pd, dst_memory, weights_memory, src_grad_memory);

      std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
    }
  }

 private:
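  // Creates the backward-weights primitive descriptor; MKL-DNN requires the
  // cached forward primitive descriptor as a hint.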
  mkldnn::inner_product_backward_weights::primitive_desc
  FcBwdWeightsPrimitiveDesc(
      const mkldnn::memory::desc& src, const mkldnn::memory::desc& diff_weights,
      const mkldnn::memory::desc& diff_dst, const mkldnn::memory::desc& bias,
      const bool with_bias,
      const mkldnn::inner_product_forward::primitive_desc& pd,
      const mkldnn::engine& engine) const {
    // Omit the bias descriptor when the op has no bias.
    auto bwd_weight_desc = with_bias
                               ? mkldnn::inner_product_backward_weights::desc(
                                     src, diff_weights, bias, diff_dst)
                               : mkldnn::inner_product_backward_weights::desc(
                                     src, diff_weights, diff_dst);

    return mkldnn::inner_product_backward_weights::primitive_desc(
        bwd_weight_desc, engine, pd);
  }

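  // Creates the backward-data primitive descriptor, likewise hinted by the
  // cached forward primitive descriptor.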
  mkldnn::inner_product_backward_data::primitive_desc FcBwdDataPrimitiveDesc(
      const mkldnn::memory::desc& diff_src, const mkldnn::memory::desc& weights,
      const mkldnn::memory::desc& diff_dst,
      const mkldnn::inner_product_forward::primitive_desc& pd,
      const mkldnn::engine& engine) const {
    auto bwd_data_desc =
        mkldnn::inner_product_backward_data::desc(diff_src, weights, diff_dst);
    return mkldnn::inner_product_backward_data::primitive_desc(bwd_data_desc,
                                                               engine, pd);
  }
};

}  // namespace operators
}  // namespace paddle

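// Only float kernels are registered: the memory wrappers above hard-code
// f32 descriptors and float buffer casts.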
REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::FCMKLDNNOpKernel<float>);

REGISTER_OP_KERNEL(fc_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::FCMKLDNNGradOpKernel<float>);