|
|
|
|
@ -48,13 +48,17 @@ class SoftmaxMKLDNNHandler
|
|
|
|
|
const mkldnn::engine mkldnn_engine,
|
|
|
|
|
platform::Place cpu_place, const Tensor* input,
|
|
|
|
|
Tensor* output, const int axis,
|
|
|
|
|
const std::string uniq_name)
|
|
|
|
|
const std::string uniq_name, bool is_inplaced)
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
|
|
|
|
|
mkldnn::softmax_backward>(
|
|
|
|
|
dev_ctx, mkldnn_engine, cpu_place,
|
|
|
|
|
// Softmax may be inplace then uniq_name is no longer unique
|
|
|
|
|
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
|
|
|
|
|
axis, uniq_name)) {
|
|
|
|
|
is_inplaced ? platform::CreateKey(
|
|
|
|
|
dev_ctx, framework::vectorize(input->dims()),
|
|
|
|
|
axis, uniq_name)
|
|
|
|
|
: platform::CreateKey(
|
|
|
|
|
dev_ctx, framework::vectorize(input->dims()),
|
|
|
|
|
uniq_name)) {
|
|
|
|
|
if (!this->isCached()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input->dims(), output->dims(),
|
|
|
|
|
@ -78,7 +82,7 @@ class SoftmaxMKLDNNHandler
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
|
|
|
|
|
mkldnn::softmax_backward>(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
|
platform::CreateKey(dev_ctx, dims, axis, uniq_name)) {
|
|
|
|
|
platform::CreateKey(dev_ctx, dims, uniq_name)) {
|
|
|
|
|
auto data_softmax_md =
|
|
|
|
|
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
auto diff_softmax_md =
|
|
|
|
|
@ -98,17 +102,18 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
const Tensor* input = ctx.Input<Tensor>("X");
|
|
|
|
|
Tensor* output = ctx.Output<Tensor>("Out");
|
|
|
|
|
bool is_inplaced = input->IsSharedBufferWith(*output);
|
|
|
|
|
|
|
|
|
|
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
|
|
|
|
|
|
|
|
|
|
SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
|
|
|
|
|
input, output, axis, ctx.OutputName("Out"));
|
|
|
|
|
input, output, axis, ctx.OutputName("Out"),
|
|
|
|
|
is_inplaced);
|
|
|
|
|
|
|
|
|
|
auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
|
|
|
|
|
// For Inplace src and and dst are the same memory object
|
|
|
|
|
auto softmax_dst_memory_p = input->IsSharedBufferWith(*output)
|
|
|
|
|
? softmax_src_memory_p
|
|
|
|
|
: handler.AcquireDstMemory(output);
|
|
|
|
|
auto softmax_dst_memory_p =
|
|
|
|
|
is_inplaced ? softmax_src_memory_p : handler.AcquireDstMemory(output);
|
|
|
|
|
|
|
|
|
|
auto softmax_p = handler.AcquireForwardPrimitive();
|
|
|
|
|
|
|
|
|
|
|