|
|
|
@ -23,67 +23,56 @@ namespace operators {
|
|
|
|
|
using paddle::framework::Tensor;
|
|
|
|
|
using paddle::platform::MKLDNNDeviceContext;
|
|
|
|
|
|
|
|
|
|
struct MKLDNNMatrixSize final {
|
|
|
|
|
explicit MKLDNNMatrixSize(const std::vector<int>& in,
|
|
|
|
|
const std::vector<int>& w)
|
|
|
|
|
: mb{in[0]}, ic{in[1]}, oc{w[1]}, h{in[2]}, w{in[3]} {}
|
|
|
|
|
|
|
|
|
|
bool is_spatial() const { return h > 2 && w > 2; }
|
|
|
|
|
|
|
|
|
|
const int mb;
|
|
|
|
|
const int ic;
|
|
|
|
|
const int oc;
|
|
|
|
|
const int h, w;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MKLDNNMD {
|
|
|
|
|
public:
|
|
|
|
|
explicit MKLDNNMD(const T* in, const T* w, bool bias)
|
|
|
|
|
: sz_(std::unique_ptr<MKLDNNMatrixSize>(new MKLDNNMatrixSize(
|
|
|
|
|
paddle::framework::vectorize2int(in->dims()),
|
|
|
|
|
paddle::framework::vectorize2int(w->dims())))) {
|
|
|
|
|
: in{paddle::framework::vectorize2int(in->dims())},
|
|
|
|
|
w{paddle::framework::vectorize2int(w->dims())} {
|
|
|
|
|
with_bias_ = bias;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mkldnn::memory::desc dst() const {
|
|
|
|
|
return platform::MKLDNNMemDesc({sz_->mb, sz_->oc},
|
|
|
|
|
return platform::MKLDNNMemDesc({in[0], w[1]},
|
|
|
|
|
mkldnn::memory::data_type::f32,
|
|
|
|
|
mkldnn::memory::format::nc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mkldnn::memory::desc src() const {
|
|
|
|
|
return sz_->is_spatial()
|
|
|
|
|
? platform::MKLDNNMemDesc({sz_->mb, sz_->ic, sz_->h, sz_->w},
|
|
|
|
|
return is_spatial()
|
|
|
|
|
? platform::MKLDNNMemDesc({in[0], in[1], in[2], in[3]},
|
|
|
|
|
mkldnn::memory::data_type::f32,
|
|
|
|
|
mkldnn::memory::format::nchw)
|
|
|
|
|
: platform::MKLDNNMemDesc({sz_->mb, sz_->ic},
|
|
|
|
|
: platform::MKLDNNMemDesc({in[0], in[1]},
|
|
|
|
|
mkldnn::memory::data_type::f32,
|
|
|
|
|
mkldnn::memory::format::nc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mkldnn::memory::desc weights() const {
|
|
|
|
|
return sz_->is_spatial()
|
|
|
|
|
? platform::MKLDNNMemDesc({sz_->oc, sz_->ic, sz_->h, sz_->w},
|
|
|
|
|
return is_spatial()
|
|
|
|
|
? platform::MKLDNNMemDesc({w[1], in[1], in[2], in[3]},
|
|
|
|
|
mkldnn::memory::data_type::f32,
|
|
|
|
|
mkldnn::memory::format::oihw)
|
|
|
|
|
: platform::MKLDNNMemDesc({sz_->oc, sz_->ic},
|
|
|
|
|
: platform::MKLDNNMemDesc({w[1], in[1]},
|
|
|
|
|
mkldnn::memory::data_type::f32,
|
|
|
|
|
mkldnn::memory::format::oi);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mkldnn::memory::desc bias() const {
|
|
|
|
|
return with_bias_
|
|
|
|
|
? platform::MKLDNNMemDesc({sz_->oc},
|
|
|
|
|
mkldnn::memory::data_type::f32,
|
|
|
|
|
? platform::MKLDNNMemDesc({w[1]}, mkldnn::memory::data_type::f32,
|
|
|
|
|
mkldnn::memory::format::format_undef)
|
|
|
|
|
: platform::MKLDNNMemDesc({}, mkldnn::memory::data_type::f32,
|
|
|
|
|
mkldnn::memory::format::format_undef);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<MKLDNNMatrixSize> sz_;
|
|
|
|
|
bool is_spatial() const { return in.size() > 1 && w.size() > 1; }
|
|
|
|
|
|
|
|
|
|
std::vector<int> in;
|
|
|
|
|
std::vector<int> w;
|
|
|
|
|
bool with_bias_;
|
|
|
|
|
bool is_spatial_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class MKLDNNMemory {
|
|
|
|
|