|
|
|
@ -53,25 +53,60 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
"Softmax input and output dimensions should match");
|
|
|
|
|
// Same memory descriptor to be used for input and output
|
|
|
|
|
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
|
|
|
|
|
// Currently only supports NC data format
|
|
|
|
|
// TODO(jczaja-intel): support more formats
|
|
|
|
|
auto softmax_md =
|
|
|
|
|
MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
|
|
|
|
|
// Normalization is performed along the innermost dimension, e.g. C of NC
|
|
|
|
|
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
|
|
|
|
|
softmax_md, 1 /*dim: C*/);
|
|
|
|
|
// create memory primitives
|
|
|
|
|
auto softmax_src_memory =
|
|
|
|
|
memory({softmax_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
auto softmax_dst_memory =
|
|
|
|
|
memory({softmax_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(const_cast<T*>(output_data)));
|
|
|
|
|
auto softmax_prim_desc =
|
|
|
|
|
softmax_forward::primitive_desc(softmax_desc, mkldnn_engine);
|
|
|
|
|
auto softmax = softmax_forward(softmax_prim_desc, softmax_src_memory,
|
|
|
|
|
softmax_dst_memory);
|
|
|
|
|
std::vector<primitive> pipeline{softmax};
|
|
|
|
|
// Generate keys for storing/retrieving primitives for this operator
|
|
|
|
|
// TODO(jczaja): Each MKLDNN operator may have a different hashing function
|
|
|
|
|
auto gethash = [](memory::dims& operand_dims) {
|
|
|
|
|
return std::string(std::to_string(operand_dims[0]) + "-" +
|
|
|
|
|
std::to_string(operand_dims[1]));
|
|
|
|
|
};
|
|
|
|
|
const std::string key = gethash(softmax_tz);
|
|
|
|
|
const std::string key_softmax_p = key + "@softmax_p";
|
|
|
|
|
const std::string key_softmax_src_mem_p = key + "@softmax_src_mem_p";
|
|
|
|
|
const std::string key_softmax_dst_mem_p = key + "@softmax_dst_mem_p";
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<void> softmax_p = dev_ctx.GetBlob(key_softmax_p);
|
|
|
|
|
if (softmax_p == nullptr) {
|
|
|
|
|
// Currently only NC data format is supported
|
|
|
|
|
auto softmax_md =
|
|
|
|
|
MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
|
|
|
|
|
// Normalization is performed along the innermost dimension, e.g. C of NC
|
|
|
|
|
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
|
|
|
|
|
softmax_md, 1 /*dim: C*/);
|
|
|
|
|
// create memory primitives
|
|
|
|
|
auto softmax_src_memory_p = std::make_shared<memory>(
|
|
|
|
|
memory::primitive_desc{softmax_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
dev_ctx.SetBlob(key_softmax_src_mem_p, softmax_src_memory_p);
|
|
|
|
|
auto softmax_dst_memory_p = std::make_shared<memory>(
|
|
|
|
|
memory::primitive_desc{softmax_md, mkldnn_engine},
|
|
|
|
|
static_cast<void*>(output_data));
|
|
|
|
|
dev_ctx.SetBlob(key_softmax_dst_mem_p, softmax_dst_memory_p);
|
|
|
|
|
|
|
|
|
|
auto softmax_forward_pd =
|
|
|
|
|
std::make_shared<softmax_forward::primitive_desc>(softmax_desc,
|
|
|
|
|
mkldnn_engine);
|
|
|
|
|
softmax_p = std::make_shared<softmax_forward>(
|
|
|
|
|
*(softmax_forward_pd.get()),
|
|
|
|
|
*(static_cast<memory*>(softmax_src_memory_p.get())),
|
|
|
|
|
*(static_cast<memory*>(softmax_dst_memory_p.get())));
|
|
|
|
|
dev_ctx.SetBlob(key_softmax_p, softmax_p);
|
|
|
|
|
} else {
|
|
|
|
|
// Primitives already exist
|
|
|
|
|
auto src_memory_p = std::static_pointer_cast<memory>(
|
|
|
|
|
dev_ctx.GetBlob(key_softmax_src_mem_p));
|
|
|
|
|
PADDLE_ENFORCE(src_memory_p != nullptr,
|
|
|
|
|
"Fail to find softmax src mem_p in device context");
|
|
|
|
|
auto dst_memory_p = std::static_pointer_cast<memory>(
|
|
|
|
|
dev_ctx.GetBlob(key_softmax_dst_mem_p));
|
|
|
|
|
PADDLE_ENFORCE(dst_memory_p != nullptr,
|
|
|
|
|
"Fail to find softmax dst mem_p in device context");
|
|
|
|
|
src_memory_p->set_data_handle(
|
|
|
|
|
reinterpret_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
dst_memory_p->set_data_handle(output_data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<primitive> pipeline{
|
|
|
|
|
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
|
|
|
|
|
stream(stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
|
|
|
|
|
const bool is_test = ctx.Attr<bool>("is_test");
|
|
|
|
|