@@ -197,6 +197,130 @@ class MKLDNNHandler {
bool is_reusing_;
};
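
// Handler for the MKL-DNN transpose operator: the transpose is realized as
// a reorder from the source memory into a destination memory whose strides
// encode the requested axis permutation (see Axis2MemoryDesc below).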
class TransposeMKLDNNHandler : public MKLDNNHandler {
 public:
  TransposeMKLDNNHandler(const std::vector<int>& dims,
                         const std::vector<int>& axis,
                         const platform::MKLDNNDeviceContext& dev_ctx,
                         mkldnn::engine engine, const std::string& base_key)
      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
        dims_(dims),
        axis_(axis),
        logical_axis_(dims.size(), 0) {}
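
  // Returns the cached user source memory for this handler's key, creating
  // it on first use; on reuse only the data handle is swapped.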
  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
      const mkldnn::memory::format& fmt, void* ptr) {
    auto local_key = key_ + "@user_src_mem_p";
    auto mem_p =
        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
                   "Failed to find mem primitive in device context");
    if (mem_p == nullptr) {
      // Make the memory descriptor from the input format, unless it cannot
      // be trusted (nchw); in that case build the memory format manually.
      for (size_t i = 0; i < logical_axis_.size(); ++i) {
        logical_axis_[i] = i;
      }
      auto src_md = fmt != mkldnn::memory::format::nchw
                        ? platform::MKLDNNMemDesc(
                              dims_, platform::MKLDNNGetDataType<float>(), fmt)
                        : Axis2MemoryDesc(dims_, logical_axis_);
      mem_p = std::make_shared<mkldnn::memory>(
          mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
      dev_ctx_.SetBlob(local_key, mem_p);
    } else {
      mem_p->set_data_handle(ptr);
      // Mark that reuse happened. All primitives from an operator instance
      // should be reused, or none of them, so we check for consistency.
      is_reusing_ = true;
    }
    return mem_p;
  }
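
  // Returns the destination memory whose strides realize the axis
  // permutation; the output buffer is allocated with the size reported by
  // the primitive descriptor.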
  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
                                                   platform::Place place) {
    auto local_key = key_ + "@user_dst_mem_p";
    auto mem_p =
        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
                   "Failed to find mem primitive in device context");
    if (mem_p == nullptr) {
      auto dst_mdp = mkldnn::memory::primitive_desc{
          Axis2MemoryDesc(dims_, axis_), engine_};

      auto dst_data = output->mutable_data<float>(
          place, paddle::memory::Allocator::kDefault, dst_mdp.get_size());

      mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data);
      dev_ctx_.SetBlob(local_key, mem_p);
    } else {
      auto dst_data = output->mutable_data<float>(place);
      mem_p->set_data_handle(dst_data);
      // Mark that reuse happened. All primitives from an operator instance
      // should be reused, or none of them, so we check for consistency.
      is_reusing_ = true;
    }
    return mem_p;
  }
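
  // The transpose itself is a single reorder primitive between the source
  // and destination memories; it is cached just like the memories above.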
  std::shared_ptr<mkldnn::reorder> AcquireTranspose(
      std::shared_ptr<mkldnn::memory> dst_memory_p,
      std::shared_ptr<mkldnn::memory> src_memory_p) {
    auto prim_key = key_ + "@transpose_p";
    auto transpose_p =
        std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
    PADDLE_ENFORCE((transpose_p != nullptr) || (is_reusing_ == false),
                   "Failed to find transpose primitive in device context");
    if (transpose_p == nullptr) {
      transpose_p =
          std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
      dev_ctx_.SetBlob(prim_key, transpose_p);
    } else {
      is_reusing_ = true;
    }
    return transpose_p;
  }
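
  // Builds the cache key from the input shape, the permutation, and a
  // caller-supplied suffix.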
  static std::string GetHash(const std::vector<int>& shape,
                             const std::vector<int>& axis,
                             const std::string& suffix) {
    return dims2str(shape) + dims2str(axis) + suffix;
  }

 protected:
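  // Builds a blocked memory descriptor whose strides follow the given axis
  // permutation: walking the permuted dims from the innermost one, each dim
  // axis[i] gets the running product of the sizes of the dims laid out after
  // it, so reading this memory linearly yields the transposed order.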
  mkldnn_memory_desc_t Axis2MemoryDesc(const std::vector<int>& nchw_tz,
                                       const std::vector<int>& axis) {
    mkldnn_memory_desc_t mem_fmt;

    mem_fmt.primitive_kind = mkldnn_memory;
    mem_fmt.ndims = axis.size();
    for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
      // Logical dimensions (nchw order, regardless of the physical layout).
      mem_fmt.dims[i] = nchw_tz[i];
    }
    mem_fmt.data_type = mkldnn_f32;
    mem_fmt.format = mkldnn_blocked;

    unsigned int total_stride = 1;
    for (int i = nchw_tz.size() - 1; i >= 0; --i) {
      // Logical dimensions (nchw order, regardless of the physical layout).
      mem_fmt.layout_desc.blocking.padding_dims[i] = nchw_tz[i];
      mem_fmt.layout_desc.blocking.block_dims[i] = 1;
      mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0;  // no offset
      mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
      mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
      total_stride *= nchw_tz[axis[i]];
    }
    mem_fmt.layout_desc.blocking.offset_padding = 0;  // no initial offset
    return mem_fmt;
  }

 private:
  std::vector<int> dims_;
  std::vector<int> axis_;
  std::vector<int> logical_axis_;
};
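
// A minimal usage sketch (illustrative only; `ctx`, `input`, `output`,
// `axis`, `dev_ctx`, and `mkldnn_engine` are assumed to come from the
// surrounding operator kernel and are not part of this header):
//
//   std::vector<int> nchw_tz = framework::vectorize2int(input->dims());
//   const std::string key = TransposeMKLDNNHandler::GetHash(
//       nchw_tz, axis, ctx.op().Output("Out"));
//
//   TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx, mkldnn_engine,
//                                  key);
//
//   auto src_mem_p = handler.AcquireSrcMemory(
//       input->format(), platform::to_void_cast<float>(input->data<float>()));
//   auto dst_mem_p = handler.AcquireDstMemory(output, ctx.GetPlace());
//   auto transpose_p = handler.AcquireTranspose(dst_mem_p, src_mem_p);
//
//   std::vector<mkldnn::primitive> pipeline{*transpose_p};
//   mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();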

template <class forward_t, class backward_data_t, class backward_weights_t>
class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
 public: