|
|
|
@ -33,27 +33,45 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
|
|
|
|
|
ElemwiseGradKernel<T>::Compute(ctx);
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
|
|
|
|
|
const auto& onednn_engine = dev_ctx.GetEngine();
|
|
|
|
|
|
|
|
|
|
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
|
|
|
|
|
auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
|
|
|
|
|
in->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
in->set_format(out->format());
|
|
|
|
|
};
|
|
|
|
|
auto tz = paddle::framework::vectorize<int64_t>(dout->dims());
|
|
|
|
|
memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type());
|
|
|
|
|
std::string key = platform::CreateKey(dev_ctx, tz, dout->format(),
|
|
|
|
|
dout->format(), dout_type);
|
|
|
|
|
platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, dev_ctx,
|
|
|
|
|
onednn_engine, key);
|
|
|
|
|
|
|
|
|
|
mkldnn::stream astream(onednn_engine);
|
|
|
|
|
auto reorder_src_memory_p = handler.AcquireSrcMemory(
|
|
|
|
|
dout->format(), platform::to_void_cast(dout->data<T>()));
|
|
|
|
|
|
|
|
|
|
// TODO(jczaja): Double check if vcopy works for blocked data
|
|
|
|
|
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
|
|
|
|
|
if (dx) {
|
|
|
|
|
blas.VCOPY(dout->numel(), dout->data<T>(),
|
|
|
|
|
dx->mutable_data<T>(ctx.GetPlace()));
|
|
|
|
|
set_mkldnn_format(dx, dout);
|
|
|
|
|
auto reorder_dst_memory_p =
|
|
|
|
|
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
|
|
|
|
|
auto reorder_p =
|
|
|
|
|
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
|
|
|
|
|
platform::RecordEvent record_reorder("int_reorder",
|
|
|
|
|
platform::EventRole::kUniqueOp);
|
|
|
|
|
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
|
|
|
|
|
astream.wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dy) {
|
|
|
|
|
blas.VCOPY(dout->numel(), dout->data<T>(),
|
|
|
|
|
dy->mutable_data<T>(ctx.GetPlace()));
|
|
|
|
|
set_mkldnn_format(dy, dout);
|
|
|
|
|
auto reorder_dst_memory_p =
|
|
|
|
|
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
|
|
|
|
|
auto reorder_p =
|
|
|
|
|
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
|
|
|
|
|
platform::RecordEvent record_reorder("int_reorder",
|
|
|
|
|
platform::EventRole::kUniqueOp);
|
|
|
|
|
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
|
|
|
|
|
astream.wait();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|