|
|
|
@ -54,102 +54,84 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
|
|
|
|
|
const auto& mkldnn_engine = dev_ctx.GetEngine();
|
|
|
|
|
auto in_vars = ctx.MultiInputVar("X");
|
|
|
|
|
|
|
|
|
|
const int N = in_vars.size();
|
|
|
|
|
auto out_var = ctx.OutputVar("Out");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(in_vars.empty(), true, platform::errors::InvalidArgument(
|
|
|
|
|
"Input variable is empty."));
|
|
|
|
|
bool in_place = out_var == in_vars[0];
|
|
|
|
|
|
|
|
|
|
if (out_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
LoDTensor* output = ctx.Output<LoDTensor>("Out");
|
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto dst_tz = framework::vectorize<int64_t>(output->dims());
|
|
|
|
|
auto src_tz = dst_tz;
|
|
|
|
|
MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::undef};
|
|
|
|
|
std::vector<float> scales;
|
|
|
|
|
std::vector<memory::desc> srcs_md;
|
|
|
|
|
std::vector<mkldnn::memory> srcs_mem;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_vars[0]->IsType<LoDTensor>(), true,
|
|
|
|
|
"Input[0] must be LoDTensors");
|
|
|
|
|
auto& input0 = in_vars[0]->Get<LoDTensor>();
|
|
|
|
|
PADDLE_ENFORCE_EQ(input0.layout(), DataLayout::kMKLDNN,
|
|
|
|
|
"Wrong layout set for inputs[0] tensor");
|
|
|
|
|
PADDLE_ENFORCE_NE(input0.format(), MKLDNNMemoryFormat::undef,
|
|
|
|
|
"Wrong format set for inputs[0] tensor");
|
|
|
|
|
|
|
|
|
|
MKLDNNMemoryFormat input_format = input0.format();
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < N; i++) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_vars[i]->IsType<LoDTensor>(), true,
|
|
|
|
|
"all inputs must be all LoDTensors");
|
|
|
|
|
auto& input = in_vars[i]->Get<LoDTensor>();
|
|
|
|
|
PADDLE_ENFORCE_EQ(input.layout(), DataLayout::kMKLDNN,
|
|
|
|
|
"Wrong layout set for inputs");
|
|
|
|
|
PADDLE_ENFORCE_NE(input.format(), MKLDNNMemoryFormat::undef,
|
|
|
|
|
"Wrong format set for inputs");
|
|
|
|
|
|
|
|
|
|
if (input.numel() == 0) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T* input_data = input.data<T>();
|
|
|
|
|
|
|
|
|
|
auto src_md =
|
|
|
|
|
memory::desc(src_tz, memory::data_type::f32, input_format);
|
|
|
|
|
auto src_mem = memory(src_md, mkldnn_engine, to_void_cast(input_data));
|
|
|
|
|
srcs_md.push_back(src_md);
|
|
|
|
|
srcs_mem.push_back(src_mem);
|
|
|
|
|
scales.push_back(1.0);
|
|
|
|
|
}
|
|
|
|
|
LoDTensor* output = ctx.Output<LoDTensor>("Out");
|
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto dst_md =
|
|
|
|
|
memory::desc(dst_tz, memory::data_type::f32, MKLDNNMemoryFormat::any);
|
|
|
|
|
auto dst_tz = framework::vectorize<int64_t>(output->dims());
|
|
|
|
|
auto src_tz = dst_tz;
|
|
|
|
|
MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::undef};
|
|
|
|
|
std::vector<float> scales;
|
|
|
|
|
std::vector<memory::desc> srcs_md;
|
|
|
|
|
std::vector<mkldnn::memory> srcs_mem;
|
|
|
|
|
|
|
|
|
|
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_md, mkldnn_engine);
|
|
|
|
|
auto& input0 = in_vars[0]->Get<LoDTensor>();
|
|
|
|
|
in_place = (input0.numel() > 0) && (input0.data<T>() == output_data);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<memory> dst_mem;
|
|
|
|
|
if (in_place) {
|
|
|
|
|
dst_mem.reset(new memory(sum_pd.dst_desc(), mkldnn_engine));
|
|
|
|
|
} else {
|
|
|
|
|
dst_mem.reset(
|
|
|
|
|
new memory(sum_pd.dst_desc(), mkldnn_engine, output_data));
|
|
|
|
|
}
|
|
|
|
|
MKLDNNMemoryFormat input_format = input0.format();
|
|
|
|
|
|
|
|
|
|
auto sum_prim = mkldnn::sum(sum_pd);
|
|
|
|
|
output_format = platform::GetMKLDNNFormat(sum_pd.dst_desc());
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::reorder> reorder_p;
|
|
|
|
|
std::shared_ptr<memory> target_mem;
|
|
|
|
|
if (in_place) {
|
|
|
|
|
output_format = input_format;
|
|
|
|
|
target_mem.reset(
|
|
|
|
|
new memory({{src_tz}, memory::data_type::f32, output_format},
|
|
|
|
|
mkldnn_engine, output_data));
|
|
|
|
|
reorder_p = std::make_shared<reorder>(*dst_mem, *target_mem);
|
|
|
|
|
for (size_t i = 0; i < in_vars.size(); i++) {
|
|
|
|
|
auto& input_it = in_vars[i]->Get<LoDTensor>();
|
|
|
|
|
if (input_it.numel() == 0) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mkldnn::stream astream(mkldnn_engine);
|
|
|
|
|
std::unordered_map<int, memory> args;
|
|
|
|
|
for (size_t i = 0; i < srcs_mem.size(); ++i) {
|
|
|
|
|
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, srcs_mem.at(i)});
|
|
|
|
|
}
|
|
|
|
|
args.insert({MKLDNN_ARG_DST, *dst_mem});
|
|
|
|
|
const T* input_data = input_it.data<T>();
|
|
|
|
|
|
|
|
|
|
sum_prim.execute(astream, args);
|
|
|
|
|
astream.wait();
|
|
|
|
|
auto src_md = memory::desc(src_tz, memory::data_type::f32, input_format);
|
|
|
|
|
auto src_mem = memory(src_md, mkldnn_engine, to_void_cast(input_data));
|
|
|
|
|
srcs_md.push_back(src_md);
|
|
|
|
|
srcs_mem.push_back(src_mem);
|
|
|
|
|
scales.push_back(1.0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (in_place) {
|
|
|
|
|
reorder_p->execute(astream, *dst_mem, *target_mem);
|
|
|
|
|
astream.wait();
|
|
|
|
|
}
|
|
|
|
|
auto dst_md =
|
|
|
|
|
memory::desc(dst_tz, memory::data_type::f32, MKLDNNMemoryFormat::any);
|
|
|
|
|
|
|
|
|
|
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_md, mkldnn_engine);
|
|
|
|
|
|
|
|
|
|
output->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
output->set_format(output_format);
|
|
|
|
|
} else { // Fallback to naive version
|
|
|
|
|
SumKernel<CPUDeviceContext, T> reference_kernel;
|
|
|
|
|
reference_kernel.Compute(ctx);
|
|
|
|
|
std::shared_ptr<memory> dst_mem;
|
|
|
|
|
if (in_place) {
|
|
|
|
|
dst_mem.reset(new memory(sum_pd.dst_desc(), mkldnn_engine));
|
|
|
|
|
} else {
|
|
|
|
|
dst_mem.reset(new memory(sum_pd.dst_desc(), mkldnn_engine, output_data));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto sum_prim = mkldnn::sum(sum_pd);
|
|
|
|
|
output_format = platform::GetMKLDNNFormat(sum_pd.dst_desc());
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::reorder> reorder_p;
|
|
|
|
|
std::shared_ptr<memory> target_mem;
|
|
|
|
|
if (in_place) {
|
|
|
|
|
output_format = input_format;
|
|
|
|
|
target_mem.reset(
|
|
|
|
|
new memory({{src_tz}, memory::data_type::f32, output_format},
|
|
|
|
|
mkldnn_engine, output_data));
|
|
|
|
|
reorder_p = std::make_shared<reorder>(*dst_mem, *target_mem);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mkldnn::stream astream(mkldnn_engine);
|
|
|
|
|
std::unordered_map<int, memory> args;
|
|
|
|
|
for (size_t i = 0; i < srcs_mem.size(); ++i) {
|
|
|
|
|
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, srcs_mem.at(i)});
|
|
|
|
|
}
|
|
|
|
|
args.insert({MKLDNN_ARG_DST, *dst_mem});
|
|
|
|
|
|
|
|
|
|
sum_prim.execute(astream, args);
|
|
|
|
|
astream.wait();
|
|
|
|
|
|
|
|
|
|
if (in_place) {
|
|
|
|
|
reorder_p->execute(astream, *dst_mem, *target_mem);
|
|
|
|
|
astream.wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
output->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
output->set_format(output_format);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|