|
|
|
@ -13,7 +13,7 @@
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/activation_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_reuse.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -99,20 +99,21 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
|
|
|
|
|
auto src_format =
|
|
|
|
|
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
|
|
|
|
|
|
|
|
|
|
const std::string key = gethash(src_tz, algorithm);
|
|
|
|
|
const std::string key_src_data =
|
|
|
|
|
key + ctx.op().Output("Out") + "@eltwise_fwd_src_data";
|
|
|
|
|
const std::string key_src_layout =
|
|
|
|
|
key + ctx.op().Output("Out") + "@eltwise_fwd_src_layout";
|
|
|
|
|
const std::string key_with_layout = key + std::to_string(src_format);
|
|
|
|
|
const std::string key_src_mem = key_with_layout + "@eltwise_fwd_src_mem";
|
|
|
|
|
const std::string key_dst_mem = key_with_layout + "@eltwise_fwd_dst_mem";
|
|
|
|
|
const std::string key_fwd = key_with_layout + "@eltwise_fwd";
|
|
|
|
|
const std::string key_fwd_pd = key_with_layout + "@eltwise_fwd_pd";
|
|
|
|
|
|
|
|
|
|
bool is_test = ctx.Attr<bool>("is_test");
|
|
|
|
|
|
|
|
|
|
// TODO(jczaja): When adding leaky-relu , swish , elu make sure to extend key
|
|
|
|
|
// with alpha, beta
|
|
|
|
|
std::string key = platform::MKLDNNHandler::GetHash(
|
|
|
|
|
src_tz, std::to_string(algorithm) + ctx.op().Output("Out"));
|
|
|
|
|
|
|
|
|
|
// TODO(jczaja): Make it Thread safe
|
|
|
|
|
// save input data and layout to be referred in backward path
|
|
|
|
|
const std::string key_src_data = key + "@eltwise_fwd_src_data";
|
|
|
|
|
const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
|
|
|
|
|
// Just in case some int8 models are run interchangebly
|
|
|
|
|
// with float models then format maybe diffrent
|
|
|
|
|
key += std::to_string(src_format);
|
|
|
|
|
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
|
|
|
|
|
auto p_src_data = std::make_shared<const T *>(x_data);
|
|
|
|
|
auto p_src_layout = std::make_shared<memory::format>(src_format);
|
|
|
|
|
if (!is_test) {
|
|
|
|
@ -120,65 +121,34 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
|
|
|
|
|
dev_ctx.SetBlob(key_src_layout, p_src_layout);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
|
|
|
|
|
dev_ctx.GetBlob(key_fwd));
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<memory> dst_memory;
|
|
|
|
|
|
|
|
|
|
if (p_fwd == nullptr) {
|
|
|
|
|
// create mkldnn memory for input X
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), src_format);
|
|
|
|
|
auto src_memory = std::shared_ptr<memory>(
|
|
|
|
|
new memory({src_md, mkldnn_engine}, to_void_cast(x_data)));
|
|
|
|
|
// save src_memory to be referred in backward path
|
|
|
|
|
dev_ctx.SetBlob(key_src_mem, src_memory);
|
|
|
|
|
|
|
|
|
|
// create primitive descriptor for activation forward and save it
|
|
|
|
|
auto mkldnn_forward_prop_kind = is_test
|
|
|
|
|
? mkldnn::prop_kind::forward_inference
|
|
|
|
|
: mkldnn::prop_kind::forward_training;
|
|
|
|
|
auto forward_desc = mkldnn::eltwise_forward::desc(
|
|
|
|
|
mkldnn_forward_prop_kind, algorithm,
|
|
|
|
|
src_memory->get_primitive_desc().desc(), alpha, beta);
|
|
|
|
|
auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
|
|
|
|
|
forward_desc, mkldnn_engine);
|
|
|
|
|
|
|
|
|
|
// save prim desc into global device context to be referred in backward path
|
|
|
|
|
if (!is_test) dev_ctx.SetBlob(key_fwd_pd, forward_pd);
|
|
|
|
|
|
|
|
|
|
// create mkldnn memory for output y
|
|
|
|
|
dst_memory =
|
|
|
|
|
std::make_shared<memory>(forward_pd->dst_primitive_desc(), y_data);
|
|
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_dst_mem, dst_memory);
|
|
|
|
|
|
|
|
|
|
// create activation primitive
|
|
|
|
|
p_fwd = std::make_shared<mkldnn::eltwise_forward>(*forward_pd, *src_memory,
|
|
|
|
|
*dst_memory);
|
|
|
|
|
dev_ctx.SetBlob(key_fwd, p_fwd);
|
|
|
|
|
} else {
|
|
|
|
|
// primitives already exist
|
|
|
|
|
auto src_memory =
|
|
|
|
|
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
|
|
|
|
|
PADDLE_ENFORCE(src_memory != nullptr,
|
|
|
|
|
"Fail to find eltwise src_memory in device context.");
|
|
|
|
|
dst_memory =
|
|
|
|
|
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
|
|
|
|
|
PADDLE_ENFORCE(dst_memory != nullptr,
|
|
|
|
|
"Fail to find eltwise dst_memory in device context.");
|
|
|
|
|
|
|
|
|
|
src_memory->set_data_handle(platform::to_void_cast(x_data));
|
|
|
|
|
dst_memory->set_data_handle(y_data);
|
|
|
|
|
platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
|
|
|
|
|
|
|
|
|
|
auto md = platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
src_format);
|
|
|
|
|
|
|
|
|
|
auto activation_pd = handler.AcquireActivationPrimitiveDescriptor(
|
|
|
|
|
is_test ? mkldnn::prop_kind::forward_inference
|
|
|
|
|
: mkldnn::prop_kind::forward_training,
|
|
|
|
|
algorithm, md, alpha, beta);
|
|
|
|
|
|
|
|
|
|
auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));
|
|
|
|
|
// jczaja: Workaround, src_memory_p is needed in BWD so it has
|
|
|
|
|
// to be accessible under key not dependant on TID
|
|
|
|
|
if (!is_test) {
|
|
|
|
|
dev_ctx.SetBlob(key_src_mem, src_memory_p);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dst_memory_p =
|
|
|
|
|
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(y_data));
|
|
|
|
|
auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);
|
|
|
|
|
|
|
|
|
|
// push primitive to stream and wait until it's executed
|
|
|
|
|
std::vector<primitive> pipeline;
|
|
|
|
|
pipeline.push_back(*p_fwd);
|
|
|
|
|
pipeline.push_back(*activation_p);
|
|
|
|
|
stream(stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
|
|
|
|
|
y->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
y->set_format(GetMKLDNNFormat(*dst_memory));
|
|
|
|
|
y->set_format(GetMKLDNNFormat(*dst_memory_p));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -199,90 +169,51 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
|
|
|
|
|
auto diff_y_format =
|
|
|
|
|
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
|
|
|
|
|
|
|
|
|
|
const std::string key = gethash(diff_dst_tz, algorithm);
|
|
|
|
|
const std::string key_src_data =
|
|
|
|
|
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
|
|
|
|
|
const std::string key_src_layout =
|
|
|
|
|
key + ctx.op().Input("Out") + "@eltwise_fwd_src_layout";
|
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
|
|
|
|
|
|
|
|
|
|
std::string key = platform::MKLDNNHandler::GetHash(
|
|
|
|
|
diff_dst_tz, std::to_string(algorithm) + ctx.op().Input("Out"));
|
|
|
|
|
|
|
|
|
|
const std::string key_src_data = key + "@eltwise_fwd_src_data";
|
|
|
|
|
const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
|
|
|
|
|
|
|
|
|
|
// Get Data from FWD op
|
|
|
|
|
const auto p_src_layout =
|
|
|
|
|
std::static_pointer_cast<memory::format>(dev_ctx.GetBlob(key_src_layout));
|
|
|
|
|
const std::string key_src_mem =
|
|
|
|
|
key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
|
|
|
|
|
const std::string key_fwd_pd =
|
|
|
|
|
key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
|
|
|
|
|
const std::string key_with_layouts =
|
|
|
|
|
key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
|
|
|
|
|
const std::string key_diff_src_mem =
|
|
|
|
|
key_with_layouts + "@eltwise_diff_src_mem";
|
|
|
|
|
const std::string key_diff_dst_mem =
|
|
|
|
|
key_with_layouts + "@eltwise_diff_dst_mem";
|
|
|
|
|
const std::string key_grad = key_with_layouts + "@eltwise_grad";
|
|
|
|
|
|
|
|
|
|
const auto p_src_data =
|
|
|
|
|
std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
|
|
|
|
|
|
|
|
|
|
key += std::to_string(*p_src_layout);
|
|
|
|
|
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
|
|
|
|
|
auto src_memory =
|
|
|
|
|
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
|
|
|
|
|
PADDLE_ENFORCE(src_memory != nullptr,
|
|
|
|
|
"Fail to find src_memory in device context");
|
|
|
|
|
src_memory->set_data_handle(*p_src_data);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<memory> diff_src_memory;
|
|
|
|
|
|
|
|
|
|
auto p_grad = std::static_pointer_cast<mkldnn::eltwise_backward>(
|
|
|
|
|
dev_ctx.GetBlob(key_grad));
|
|
|
|
|
|
|
|
|
|
if (p_grad == nullptr) {
|
|
|
|
|
// create mkldnn memory for input diff_y
|
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
|
|
|
|
|
auto diff_dst_memory = std::shared_ptr<memory>(
|
|
|
|
|
new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data)));
|
|
|
|
|
dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
|
|
|
|
|
|
|
|
|
|
// retrieve eltwise primitive desc from device context
|
|
|
|
|
auto forward_pd =
|
|
|
|
|
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
|
|
|
|
|
dev_ctx.GetBlob(key_fwd_pd));
|
|
|
|
|
PADDLE_ENFORCE(forward_pd != nullptr,
|
|
|
|
|
"Fail to find eltwise_fwd_pd in device context");
|
|
|
|
|
|
|
|
|
|
// ceate primitive descriptor for activation backward
|
|
|
|
|
auto backward_desc = mkldnn::eltwise_backward::desc(
|
|
|
|
|
algorithm, diff_dst_memory->get_primitive_desc().desc(),
|
|
|
|
|
src_memory->get_primitive_desc().desc(), alpha, beta);
|
|
|
|
|
auto backward_pd = mkldnn::eltwise_backward::primitive_desc(
|
|
|
|
|
backward_desc, mkldnn_engine, *forward_pd);
|
|
|
|
|
|
|
|
|
|
// create mkldnn memory for output diff_src
|
|
|
|
|
diff_src_memory = std::make_shared<memory>(
|
|
|
|
|
backward_pd.diff_src_primitive_desc(), diff_x_data);
|
|
|
|
|
dev_ctx.SetBlob(key_diff_src_mem, diff_src_memory);
|
|
|
|
|
|
|
|
|
|
// create activation backward primitive
|
|
|
|
|
p_grad = std::make_shared<mkldnn::eltwise_backward>(
|
|
|
|
|
backward_pd, *src_memory, *diff_dst_memory, *diff_src_memory);
|
|
|
|
|
dev_ctx.SetBlob(key_grad, p_grad);
|
|
|
|
|
} else {
|
|
|
|
|
// primitives already exist
|
|
|
|
|
diff_src_memory = std::static_pointer_cast<mkldnn::memory>(
|
|
|
|
|
dev_ctx.GetBlob(key_diff_src_mem));
|
|
|
|
|
auto diff_dst_memory = std::static_pointer_cast<mkldnn::memory>(
|
|
|
|
|
dev_ctx.GetBlob(key_diff_dst_mem));
|
|
|
|
|
|
|
|
|
|
diff_src_memory->set_data_handle(
|
|
|
|
|
platform::to_void_reinterpret_cast(diff_x_data));
|
|
|
|
|
diff_dst_memory->set_data_handle(
|
|
|
|
|
platform::to_void_reinterpret_cast(diff_y_data));
|
|
|
|
|
}
|
|
|
|
|
platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
|
|
|
|
|
|
|
|
|
|
auto diff_dst_memory_p =
|
|
|
|
|
handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast<T>(diff_y_data));
|
|
|
|
|
|
|
|
|
|
auto activation_backward_pd =
|
|
|
|
|
handler.AcquireActivationBackwardPrimitiveDescriptor(
|
|
|
|
|
algorithm, diff_dst_md, src_memory->get_primitive_desc().desc(),
|
|
|
|
|
alpha, beta);
|
|
|
|
|
|
|
|
|
|
auto diff_src_memory_p =
|
|
|
|
|
handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data);
|
|
|
|
|
|
|
|
|
|
auto activation_backward_p = handler.AcquireActivationBackward(
|
|
|
|
|
diff_src_memory_p, diff_dst_memory_p, src_memory);
|
|
|
|
|
|
|
|
|
|
// push primitive to stream and wait until it's executed
|
|
|
|
|
std::vector<primitive> pipeline;
|
|
|
|
|
pipeline.push_back(*p_grad);
|
|
|
|
|
pipeline.push_back(*activation_backward_p);
|
|
|
|
|
stream(stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
|
|
|
|
|
diff_x->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
|
|
|
|
|
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, mkldnn::algorithm algorithm>
|
|
|
|
|