|
|
@ -18,6 +18,26 @@ limitations under the License. */
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
using mkldnn::memory; // Note: paddle has also "memory" namespace
|
|
|
|
|
|
|
|
using mkldnn::pooling_forward;
|
|
|
|
|
|
|
|
using mkldnn::pooling_backward;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Generate keys for storing/retriving primitives for this operator
|
|
|
|
|
|
|
|
// TODO(jczaja): Make hashing function more optimial
|
|
|
|
|
|
|
|
static std::string gethash(memory::dims& input_dims, std::string& pooling_type,
|
|
|
|
|
|
|
|
std::vector<int>& ksize, std::vector<int>& strides,
|
|
|
|
|
|
|
|
std::vector<int>& paddings, std::string suffix) {
|
|
|
|
|
|
|
|
auto dims2str = [](memory::dims& operand_dims) {
|
|
|
|
|
|
|
|
std::string dstr = "";
|
|
|
|
|
|
|
|
for (size_t i = 0; i < operand_dims.size(); ++i) {
|
|
|
|
|
|
|
|
dstr += std::to_string(operand_dims[i]) + "-";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return dstr;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
|
|
|
|
|
|
|
|
dims2str(paddings) + pooling_type + suffix;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
@ -34,10 +54,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
|
|
// Get an unique name from "argument" name of "Out" variable
|
|
|
|
// Get an unique name from "argument" name of "Out" variable
|
|
|
|
// This name will be used as key when saving info into device context
|
|
|
|
// This name will be used as key when saving info into device context
|
|
|
|
const std::string key = ctx.op().Output("Out");
|
|
|
|
|
|
|
|
const std::string key_pool_pd = key + "@pool_pd";
|
|
|
|
|
|
|
|
const std::string key_pool_workspace_memory =
|
|
|
|
|
|
|
|
key + "@pool_workspace_memory";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
|
|
|
|
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
|
|
|
|
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
|
|
|
|
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
|
|
|
@ -63,13 +79,28 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
|
|
|
|
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
|
|
|
|
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
|
|
|
|
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const std::string key = gethash(src_tz, pooling_type, ksize, strides,
|
|
|
|
|
|
|
|
paddings, ctx.op().Output("Out"));
|
|
|
|
|
|
|
|
const std::string key_pool_p = key + "@pool_p";
|
|
|
|
|
|
|
|
const std::string key_pool_pd = key + "@pool_pd";
|
|
|
|
|
|
|
|
const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
|
|
|
|
|
|
|
|
const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
|
|
|
|
|
|
|
|
const std::string key_pool_workspace_memory =
|
|
|
|
|
|
|
|
key + "@pool_workspace_memory";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto pool_p =
|
|
|
|
|
|
|
|
std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
|
|
|
|
|
|
|
|
if (pool_p == nullptr) {
|
|
|
|
// TODO(pzelazko-intel): support more formats
|
|
|
|
// TODO(pzelazko-intel): support more formats
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
|
|
|
|
|
|
|
|
|
|
|
auto src_md =
|
|
|
|
|
|
|
|
platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
|
|
|
|
auto dst_md =
|
|
|
|
|
|
|
|
platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(),
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
|
|
|
|
std::shared_ptr<pooling_forward::primitive_desc> pool_pd =
|
|
|
|
CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
|
|
|
|
CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
|
|
|
|
pooling_type, mkldnn_engine);
|
|
|
|
pooling_type, mkldnn_engine);
|
|
|
|
|
|
|
|
|
|
|
@ -82,18 +113,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
// save pool_workspace_memory to be referred in backward path
|
|
|
|
// save pool_workspace_memory to be referred in backward path
|
|
|
|
dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
|
|
|
|
dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
|
|
|
|
|
|
|
|
|
|
|
|
auto src_memory =
|
|
|
|
auto pool_src_memory_p = std::make_shared<memory>(
|
|
|
|
mkldnn::memory({src_md, mkldnn_engine},
|
|
|
|
memory::primitive_desc{src_md, mkldnn_engine},
|
|
|
|
static_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
static_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
auto dst_memory =
|
|
|
|
dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p);
|
|
|
|
mkldnn::memory({dst_md, mkldnn_engine},
|
|
|
|
|
|
|
|
static_cast<void*>(const_cast<T*>(output_data)));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto pool_prim = mkldnn::pooling_forward(*pool_pd, src_memory, dst_memory,
|
|
|
|
auto pool_dst_memory_p = std::make_shared<memory>(
|
|
|
|
|
|
|
|
memory::primitive_desc{dst_md, mkldnn_engine},
|
|
|
|
|
|
|
|
static_cast<void*>(output_data));
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pool_p = std::make_shared<pooling_forward>(
|
|
|
|
|
|
|
|
*pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()),
|
|
|
|
*workspace_memory);
|
|
|
|
*workspace_memory);
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_pool_p, pool_p);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// Primitives already exist
|
|
|
|
|
|
|
|
auto pool_src_memory_p =
|
|
|
|
|
|
|
|
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(pool_src_memory_p != nullptr,
|
|
|
|
|
|
|
|
"Fail to find pooling src mem_p in device context");
|
|
|
|
|
|
|
|
auto pool_dst_memory_p =
|
|
|
|
|
|
|
|
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
|
|
|
|
|
|
|
|
"Fail to find pooling dst mem_p in device context");
|
|
|
|
|
|
|
|
pool_src_memory_p->set_data_handle(
|
|
|
|
|
|
|
|
reinterpret_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
|
|
|
pool_dst_memory_p->set_data_handle(output_data);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// push primitive to stream and wait until it's executed
|
|
|
|
// push primitive to stream and wait until it's executed
|
|
|
|
std::vector<mkldnn::primitive> pipeline{pool_prim};
|
|
|
|
std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -120,8 +170,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
mkldnn::memory::primitive_desc workspace_md =
|
|
|
|
mkldnn::memory::primitive_desc workspace_md =
|
|
|
|
pooling_type == "max"
|
|
|
|
pooling_type == "max"
|
|
|
|
? pool_pd->workspace_primitive_desc()
|
|
|
|
? pool_pd->workspace_primitive_desc()
|
|
|
|
: mkldnn::memory::primitive_desc(
|
|
|
|
: mkldnn::memory::primitive_desc({{},
|
|
|
|
{{}, mkldnn::memory::f32, mkldnn::memory::format::nchw},
|
|
|
|
platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
|
|
|
mkldnn::memory::format::nchw},
|
|
|
|
engine);
|
|
|
|
engine);
|
|
|
|
|
|
|
|
|
|
|
|
auto p_workspace_memory = new mkldnn::memory(workspace_md);
|
|
|
|
auto p_workspace_memory = new mkldnn::memory(workspace_md);
|
|
|
@ -140,13 +191,6 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
|
|
|
|
|
|
|
// Get an unique name from "argument" name of "Out" variable
|
|
|
|
|
|
|
|
// This name will be used as key when referring info from device context
|
|
|
|
|
|
|
|
const std::string key = ctx.op().Input("Out");
|
|
|
|
|
|
|
|
const std::string key_pool_pd = key + "@pool_pd";
|
|
|
|
|
|
|
|
const std::string key_pool_workspace_memory =
|
|
|
|
|
|
|
|
key + "@pool_workspace_memory";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
|
|
|
|
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
|
|
|
|
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
|
|
|
|
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
|
|
|
|
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
|
|
|
|
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
|
|
|
@ -171,11 +215,26 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
std::vector<int> diff_dst_tz =
|
|
|
|
std::vector<int> diff_dst_tz =
|
|
|
|
paddle::framework::vectorize2int(out_grad->dims());
|
|
|
|
paddle::framework::vectorize2int(out_grad->dims());
|
|
|
|
|
|
|
|
|
|
|
|
auto diff_src_md = platform::MKLDNNMemDesc(diff_src_tz, mkldnn::memory::f32,
|
|
|
|
// Get an unique name from "argument" name of "Out" variable
|
|
|
|
|
|
|
|
// This name will be used as key when referring info from device context
|
|
|
|
|
|
|
|
const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides,
|
|
|
|
|
|
|
|
paddings, ctx.op().Input("Out"));
|
|
|
|
|
|
|
|
const std::string key_pool_bwd_p = key + "@pool_bwd_p";
|
|
|
|
|
|
|
|
const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
|
|
|
|
|
|
|
|
const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
|
|
|
|
|
|
|
|
const std::string key_pool_pd = key + "@pool_pd";
|
|
|
|
|
|
|
|
const std::string key_pool_workspace_memory =
|
|
|
|
|
|
|
|
key + "@pool_workspace_memory";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
|
|
|
|
|
|
|
|
dev_ctx.GetBlob(key_pool_bwd_p));
|
|
|
|
|
|
|
|
if (pool_bwd_p == nullptr) {
|
|
|
|
|
|
|
|
auto diff_src_md =
|
|
|
|
|
|
|
|
platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(),
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(diff_dst_tz, mkldnn::memory::f32,
|
|
|
|
auto diff_dst_md =
|
|
|
|
|
|
|
|
platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(),
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
|
|
// Retrieve pool_pd/pool_workspace_memory from device context
|
|
|
|
// Retrieve pool_pd/pool_workspace_memory from device context
|
|
|
|
auto pool_pd =
|
|
|
|
auto pool_pd =
|
|
|
|
std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
|
|
|
|
std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
|
|
|
@ -188,6 +247,15 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
PADDLE_ENFORCE(workspace_memory != nullptr,
|
|
|
|
PADDLE_ENFORCE(workspace_memory != nullptr,
|
|
|
|
"Fail to find workspace_memory in device context");
|
|
|
|
"Fail to find workspace_memory in device context");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto pool_diff_src_memory_p = std::make_shared<memory>(memory(
|
|
|
|
|
|
|
|
{diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data)));
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto pool_diff_dst_memory_p = std::make_shared<memory>(
|
|
|
|
|
|
|
|
memory({diff_dst_md, mkldnn_engine},
|
|
|
|
|
|
|
|
static_cast<void*>(const_cast<T*>(out_grad_data))));
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);
|
|
|
|
|
|
|
|
|
|
|
|
auto pool_bwd_desc = mkldnn::pooling_backward::desc(
|
|
|
|
auto pool_bwd_desc = mkldnn::pooling_backward::desc(
|
|
|
|
pooling_type == "max" ? mkldnn::algorithm::pooling_max
|
|
|
|
pooling_type == "max" ? mkldnn::algorithm::pooling_max
|
|
|
|
: mkldnn::algorithm::pooling_avg,
|
|
|
|
: mkldnn::algorithm::pooling_avg,
|
|
|
@ -196,18 +264,27 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
|
|
|
|
auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
|
|
|
|
pool_bwd_desc, mkldnn_engine, *pool_pd);
|
|
|
|
pool_bwd_desc, mkldnn_engine, *pool_pd);
|
|
|
|
|
|
|
|
|
|
|
|
auto diff_src_memory =
|
|
|
|
pool_bwd_p = std::make_shared<pooling_backward>(
|
|
|
|
mkldnn::memory({diff_src_md, mkldnn_engine},
|
|
|
|
pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory,
|
|
|
|
static_cast<void*>(const_cast<T*>(in_x_grad_data)));
|
|
|
|
*(pool_diff_src_memory_p));
|
|
|
|
auto diff_dst_memory =
|
|
|
|
dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
|
|
|
|
mkldnn::memory({diff_dst_md, mkldnn_engine},
|
|
|
|
} else {
|
|
|
|
static_cast<void*>(const_cast<T*>(out_grad_data)));
|
|
|
|
// Primitives already exist
|
|
|
|
|
|
|
|
auto pool_diff_src_memory_p = std::static_pointer_cast<memory>(
|
|
|
|
auto bwd_prim = mkldnn::pooling_backward(
|
|
|
|
dev_ctx.GetBlob(key_pool_diff_src_mem_p));
|
|
|
|
pool_bwd_pd, diff_dst_memory, *workspace_memory, diff_src_memory);
|
|
|
|
PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr,
|
|
|
|
|
|
|
|
"Fail to find pooling src mem_p in device context");
|
|
|
|
|
|
|
|
auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>(
|
|
|
|
|
|
|
|
dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr,
|
|
|
|
|
|
|
|
"Fail to find pooling dst mem_p in device context");
|
|
|
|
|
|
|
|
pool_diff_src_memory_p->set_data_handle(
|
|
|
|
|
|
|
|
reinterpret_cast<void*>(in_x_grad_data));
|
|
|
|
|
|
|
|
pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// push primitive to stream and wait until it's executed
|
|
|
|
// push primitive to stream and wait until it's executed
|
|
|
|
std::vector<mkldnn::primitive> pipeline{bwd_prim};
|
|
|
|
std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())};
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
} // Compute()
|
|
|
|
} // Compute()
|
|
|
|
};
|
|
|
|
};
|
|
|
|