MKLDNN layout: the code-review changes

revert-11610-move_hooks
mozga-intel 7 years ago
parent 96b4904d2f
commit 6512be59ec

@ -118,19 +118,18 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
std::shared_ptr<memory> dst_mem;
if (in_place)
if (in_place) {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
else
} else {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
}
std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
inputs.push_back(srcs_mem[i]);
}
auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
output_format =
(memory::format)sum_pd.dst_primitive_desc().desc().data.format;
output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd);
primitive reorder_prim;
std::shared_ptr<memory> target_mem;

@ -99,5 +99,11 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) {
memory.get_primitive_desc().desc().data.format);
}
inline mkldnn::memory::format GetMKLDNNFormat(
const mkldnn::sum::primitive_desc& memory) {
return static_cast<mkldnn::memory::format>(
memory.dst_primitive_desc().desc().data.format);
}
} // namespace platform
} // namespace paddle

Loading…
Cancel
Save