|
|
|
@ -198,9 +198,6 @@ class MKLDNNHandler {
|
|
|
|
|
mkldnn::engine engine_;
|
|
|
|
|
std::string key_;
|
|
|
|
|
std::string key_common_;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
static constexpr int MaxKeyLength = 256;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SumMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
@ -267,10 +264,9 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
platform::Place cpu_place,
|
|
|
|
|
const std::string& unique_name)
|
|
|
|
|
|
|
|
|
|
: platform::MKLDNNHandler(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(),
|
|
|
|
|
platform::ActivationMKLDNNHandler<T>::GetHash(
|
|
|
|
|
dims, algorithm, fmt, alpha, beta, unique_name)),
|
|
|
|
|
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
|
|
|
|
|
platform::CreateKey(dims, algorithm, fmt, alpha,
|
|
|
|
|
beta, unique_name)),
|
|
|
|
|
place_(cpu_place),
|
|
|
|
|
fwd_pd_(nullptr),
|
|
|
|
|
bwd_pd_(nullptr) {
|
|
|
|
@ -288,10 +284,9 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
platform::Place cpu_place,
|
|
|
|
|
const std::string& unique_name)
|
|
|
|
|
|
|
|
|
|
: platform::MKLDNNHandler(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(),
|
|
|
|
|
platform::ActivationMKLDNNHandler<T>::GetHash(
|
|
|
|
|
dims, algorithm, fmt, alpha, beta, unique_name)),
|
|
|
|
|
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
|
|
|
|
|
platform::CreateKey(dims, algorithm, fmt, alpha,
|
|
|
|
|
beta, unique_name)),
|
|
|
|
|
place_(cpu_place),
|
|
|
|
|
fwd_pd_(nullptr),
|
|
|
|
|
bwd_pd_(nullptr) {
|
|
|
|
@ -383,21 +378,6 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
return eltwise_bwd_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string GetHash(const memory::dims& input_dims,
|
|
|
|
|
const mkldnn::algorithm algorithm,
|
|
|
|
|
const MKLDNNMemoryFormat fmt, const float alpha,
|
|
|
|
|
const float beta, const std::string& suffix) {
|
|
|
|
|
std::string key;
|
|
|
|
|
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
|
|
|
|
|
platform::AppendKeyDims(&key, input_dims);
|
|
|
|
|
platform::AppendKey(&key, std::to_string(algorithm));
|
|
|
|
|
platform::AppendKey(&key, std::to_string(fmt));
|
|
|
|
|
platform::AppendKey(&key, std::to_string(alpha));
|
|
|
|
|
platform::AppendKey(&key, std::to_string(beta));
|
|
|
|
|
platform::AppendKey(&key, suffix);
|
|
|
|
|
return key;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
|
|
|
|
|
mkldnn::algorithm algorithm,
|
|
|
|
@ -597,22 +577,6 @@ class LRNMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
return lrn_bwd_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string GetHash(const memory::dims& input_dims, const int n,
|
|
|
|
|
const float alpha, const float beta, const float k,
|
|
|
|
|
const MKLDNNMemoryFormat& fmt,
|
|
|
|
|
const std::string& suffix) {
|
|
|
|
|
std::string key;
|
|
|
|
|
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
|
|
|
|
|
platform::AppendKeyDims(&key, input_dims);
|
|
|
|
|
platform::AppendKey(&key, std::to_string(n));
|
|
|
|
|
platform::AppendKey(&key, std::to_string(alpha));
|
|
|
|
|
platform::AppendKey(&key, std::to_string(beta));
|
|
|
|
|
platform::AppendKey(&key, std::to_string(k));
|
|
|
|
|
platform::AppendKey(&key, std::to_string(fmt));
|
|
|
|
|
platform::AppendKey(&key, suffix);
|
|
|
|
|
return key;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool is_test_;
|
|
|
|
|
std::shared_ptr<mkldnn::lrn_forward::primitive_desc> fwd_pd_;
|
|
|
|
@ -790,24 +754,6 @@ class PoolingMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
return pooling_bwd_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string GetHash(
|
|
|
|
|
const memory::dims& input_dims, const std::string& pooling_type,
|
|
|
|
|
const std::vector<int>& ksize, const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, const memory::data_type& dt,
|
|
|
|
|
const MKLDNNMemoryFormat& fmt, const std::string& suffix) {
|
|
|
|
|
std::string key;
|
|
|
|
|
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
|
|
|
|
|
platform::AppendKeyDims(&key, input_dims);
|
|
|
|
|
platform::AppendKey(&key, pooling_type);
|
|
|
|
|
platform::AppendKeyDims(&key, ksize);
|
|
|
|
|
platform::AppendKeyDims(&key, strides);
|
|
|
|
|
platform::AppendKeyDims(&key, paddings);
|
|
|
|
|
platform::AppendKey(&key, std::to_string(dt));
|
|
|
|
|
platform::AppendKey(&key, std::to_string(fmt));
|
|
|
|
|
platform::AppendKey(&key, suffix);
|
|
|
|
|
return key;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
static inline int ComputeCeiledOutput(int input_size, int kernel_size,
|
|
|
|
|
int padding, int stride) {
|
|
|
|
@ -905,12 +851,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
return transpose_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string GetHash(std::vector<int>& shape, // NOLINT
|
|
|
|
|
std::vector<int>& axis, // NOLINT
|
|
|
|
|
const std::string& suffix) {
|
|
|
|
|
return dims2str(shape) + dims2str(axis) + suffix;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
mkldnn_memory_desc_t Axis2MemoryDesc(std::vector<int>& nchw_tz, // NOLINT
|
|
|
|
|
std::vector<int>& axis // NOLINT
|
|
|
|
@ -999,14 +939,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
return reorder_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string GetHash(std::vector<int>& shape, // NOLINT
|
|
|
|
|
MKLDNNMemoryFormat in_fmt,
|
|
|
|
|
MKLDNNMemoryFormat out_fmt,
|
|
|
|
|
const std::string& suffix) {
|
|
|
|
|
return dims2str(shape) + std::to_string(in_fmt) + "->" +
|
|
|
|
|
std::to_string(out_fmt) + "#" + suffix;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::vector<int> dims_;
|
|
|
|
|
framework::proto::VarType::Type vtype_;
|
|
|
|
@ -1346,58 +1278,6 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
|
|
|
|
|
return conv_bwd_data_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Generate keys for storing/retriving primitives for this operator
|
|
|
|
|
// TODO(jczaja): Make hashing function more optimial
|
|
|
|
|
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
|
|
|
|
|
mkldnn::memory::dims& weights_dims, // NOLINT
|
|
|
|
|
const std::string& fuse_activation, // NOLINT
|
|
|
|
|
std::vector<int>& strides, // NOLINT
|
|
|
|
|
std::vector<int>& paddings, // NOLINT
|
|
|
|
|
std::vector<int>& dilations, // NOLINT
|
|
|
|
|
int groups, const std::string& suffix) {
|
|
|
|
|
return dims2str(input_dims) + dims2str(weights_dims) + fuse_activation +
|
|
|
|
|
dims2str(strides) + dims2str(paddings) + dims2str(dilations) +
|
|
|
|
|
std::to_string(groups) + suffix;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Generate keys for storing/retriving primitives for this operator
|
|
|
|
|
// TODO(jczaja): Make hashing function more optimial
|
|
|
|
|
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
|
|
|
|
|
mkldnn::memory::dims& weights_dims, // NOLINT
|
|
|
|
|
std::vector<int>& strides, // NOLINT
|
|
|
|
|
std::vector<int>& paddings, // NOLINT
|
|
|
|
|
std::vector<int>& dilations, // NOLINT
|
|
|
|
|
int groups, const std::string& suffix) {
|
|
|
|
|
return dims2str(input_dims) + dims2str(weights_dims) + dims2str(strides) +
|
|
|
|
|
dims2str(paddings) + dims2str(dilations) + std::to_string(groups) +
|
|
|
|
|
suffix;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void CreateKey(
|
|
|
|
|
std::string* key, const mkldnn::memory::dims& input_dims,
|
|
|
|
|
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, const std::vector<int>& dilations,
|
|
|
|
|
const int& groups, const mkldnn::memory::data_type& srcdt,
|
|
|
|
|
const MKLDNNMemoryFormat& format, const std::string& fuse_activation,
|
|
|
|
|
const bool& residual, const std::string& suffix) {
|
|
|
|
|
AppendKeyDims(key, input_dims);
|
|
|
|
|
|
|
|
|
|
AppendKeyDims(key, weights_dims);
|
|
|
|
|
|
|
|
|
|
AppendKeyDims(key, strides);
|
|
|
|
|
|
|
|
|
|
AppendKeyDims(key, paddings);
|
|
|
|
|
|
|
|
|
|
AppendKeyDims(key, dilations);
|
|
|
|
|
|
|
|
|
|
AppendKey(key, std::to_string(groups));
|
|
|
|
|
AppendKey(key, std::to_string(srcdt));
|
|
|
|
|
AppendKey(key, std::to_string(format));
|
|
|
|
|
AppendKey(key, fuse_activation);
|
|
|
|
|
AppendKey(key, std::to_string(residual));
|
|
|
|
|
AppendKey(key, suffix);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::shared_ptr<typename forward_t::primitive_desc> conv_pd_;
|
|
|
|
|
std::shared_ptr<typename backward_weights_t::primitive_desc>
|
|
|
|
|