|
|
|
|
@ -118,19 +118,18 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<memory> dst_mem;
|
|
|
|
|
if (in_place)
|
|
|
|
|
if (in_place) {
|
|
|
|
|
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
|
|
|
|
|
else
|
|
|
|
|
} else {
|
|
|
|
|
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
std::vector<mkldnn::primitive::at> inputs;
|
|
|
|
|
for (size_t i = 0; i < srcs_mem.size(); ++i) {
|
|
|
|
|
inputs.push_back(srcs_mem[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
|
|
|
|
|
output_format =
|
|
|
|
|
(memory::format)sum_pd.dst_primitive_desc().desc().data.format;
|
|
|
|
|
output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd);
|
|
|
|
|
|
|
|
|
|
primitive reorder_prim;
|
|
|
|
|
std::shared_ptr<memory> target_mem;
|
|
|
|
|
|