@@ -25,7 +25,7 @@
 limitations under the License. */
 
 #include "paddle/fluid/operators/sum_op.h"
-#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 
 namespace paddle {
 namespace framework {
@@ -51,6 +51,95 @@ using paddle::platform::CPUDeviceContext;
 using paddle::platform::MKLDNNDeviceContext;
 using platform::to_void_cast;
 
+template <typename T>
+class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
+ public:
+  SumMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
+                   platform::Place cpu_place,
+                   const std::vector<framework::Variable*>& in_vars,
+                   framework::LoDTensor* z, const std::string& uniq_name)
+
+      : platform::MKLDNNHandlerT<T, dnnl::sum>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(framework::vectorize(z->dims()), uniq_name)),
+        num_inputs_(0) {
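+    // One cache-key suffix per input so each src memory is stored separately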
+    for (size_t i = 0; i < in_vars.size(); i++) {
+      srcs_suffix_.push_back(std::string("-") + std::to_string(i));
+    }
+
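+    // Descriptors are built only on a cache miss; otherwise the cached
+    // primitive descriptor is reused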
+    if (!this->isCached()) {
+      auto dst_tz = framework::vectorize<int64_t>(z->dims());
+      auto src_tz = dst_tz;
+
+      std::vector<memory::desc> srcs_md;
+      for (size_t i = 0; i < in_vars.size(); i++) {
+        auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
+        if (input_it.numel() == 0) {
+          continue;
+        }
+        MKLDNNMemoryFormat input_format = input_it.format();
+        srcs_md.push_back(memory::desc(src_tz, platform::MKLDNNGetDataType<T>(),
+                                       input_format));
+        ++num_inputs_;
+      }
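+      // Plain summation: every retained input contributes with scale 1.0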
+      std::vector<float> scales(num_inputs_, 1.0);
+
+      auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
+                                 MKLDNNMemoryFormat::any);
+
+      this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
+    }
+  }
+
+  // (jczaja) sum oneDNN prim does not have a .desc attribute so
+  // we cannot use the base AcquireForwardPrimitiveDescriptor
+  void AcquireForwardPrimitiveDescriptor(
+      const memory::desc& dst_md, const std::vector<float>& scales,
+      const std::vector<memory::desc>& srcs_md) {
+    // Sum op has no backward pass, so nothing is passed from FWD to BWD
+    const std::string key_pd = this->key_ + "@fwd_pd";
+    this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
+        this->dev_ctx_.GetBlob(key_pd));
+    if (this->fwd_pd_ == nullptr) {
+      this->fwd_pd_.reset(new mkldnn::sum::primitive_desc(
+          dst_md, scales, srcs_md, this->engine_));
+      this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_);
+    }
+  }
+
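+  // Wraps the i-th non-empty input as oneDNN memory matching src_desc(i)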
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
+      const framework::Tensor& input, int i) {
+    const T* input_data = input.data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
+                                            to_void_cast<T>(input_data),
+                                            "@src_mem_p" + srcs_suffix_[i]);
+  }
+
+  using platform::MKLDNNHandlerT<T, dnnl::sum>::AcquireDstMemory;
+
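+  // Parameterless overload used for in-place execution: oneDNN allocates the
+  // dst buffer instead of writing straight into the output tensor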
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(),
+                                            "@dst_mem_p");
+  }
+
+  inline int GetNumInputs(void) { return num_inputs_; }
+
+ protected:
+  // isCached needs to be overloaded as the base one works on key_common
+  bool isCached() {
+    const std::string key_pd = this->key_ + "@fwd_pd";
+    this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
+        this->dev_ctx_.GetBlob(key_pd));
+
+    const std::string key_p = this->key_ + "@fwd_p";
+    return (this->dev_ctx_.GetBlob(key_p) != nullptr);
+  }
+
+ private:
+  int num_inputs_;
+  std::vector<std::string> srcs_suffix_;
+};
+
 template <typename T>
 class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -59,85 +148,67 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         paddle::platform::errors::PreconditionNotMet(
             "Operator DNNL Sum must use CPUPlace"));
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
     auto in_vars = ctx.MultiInputVar("X");
-    auto out_var = ctx.OutputVar("Out");
 
     PADDLE_ENFORCE_NE(in_vars.empty(), true, platform::errors::InvalidArgument(
                                                  "Input variable is empty."));
-    bool in_place = out_var == in_vars[0];
-
-    auto& input0 = in_vars[0]->Get<LoDTensor>();
     LoDTensor* output = ctx.Output<LoDTensor>("Out");
-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
 
-    auto dst_tz = framework::vectorize<int64_t>(output->dims());
-    auto src_tz = dst_tz;
-    MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::undef};
-    std::vector<float> scales;
-    std::vector<memory::desc> srcs_md;
-    std::vector<mkldnn::memory> srcs_mem;
+    auto& input0 = in_vars[0]->Get<LoDTensor>();
+    bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);
+
+    SumMKLDNNHandler<T> handler(dev_ctx, ctx.GetPlace(), in_vars, output,
+                                ctx.OutputName("Out"));
+
+    // Create list of SRC MEMs
+    std::vector<std::shared_ptr<mkldnn::memory>> srcs_mem;
+    srcs_mem.reserve(handler.GetNumInputs());
+    int input_index = 0;
     for (size_t i = 0; i < in_vars.size(); i++) {
-      auto& input_it = in_vars[i]->Get<LoDTensor>();
+      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
       if (input_it.numel() == 0) {
         continue;
       }
-
-      const T* input_data = input_it.data<T>();
-      MKLDNNMemoryFormat input_format = input_it.format();
-
-      auto src_md = memory::desc(src_tz, memory::data_type::f32, input_format);
-      auto src_mem = memory(src_md, mkldnn_engine, to_void_cast(input_data));
-      srcs_md.push_back(src_md);
-      srcs_mem.push_back(src_mem);
-      scales.push_back(1.0);
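+      // Non-empty inputs are mapped onto the matching src slot of the cached
+      // primitive descriptor; input_index stays aligned with those slots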
+      srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
+      ++input_index;
     }
 
-    auto dst_md =
-        memory::desc(dst_tz, memory::data_type::f32, MKLDNNMemoryFormat::any);
-
-    auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_md, mkldnn_engine);
-
-    std::shared_ptr<memory> dst_mem;
-    if (in_place) {
-      dst_mem.reset(new memory(sum_pd.dst_desc(), mkldnn_engine));
-    } else {
-      dst_mem.reset(new memory(sum_pd.dst_desc(), mkldnn_engine, output_data));
-    }
-
-    auto sum_prim = mkldnn::sum(sum_pd);
-    output_format = platform::GetMKLDNNFormat(sum_pd.dst_desc());
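+    // In-place: let oneDNN allocate dst so the shared input buffer is not
+    // clobbered mid-sum; otherwise map dst directly onto the output tensor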
+    auto dst_mem = in_place ? handler.AcquireDstMemory()
+                            : handler.AcquireDstMemory(output);
 
-    std::shared_ptr<mkldnn::reorder> reorder_p;
-    std::shared_ptr<memory> target_mem;
-    if (in_place) {
-      output_format = input0.format();
-      target_mem.reset(
-          new memory({{src_tz}, memory::data_type::f32, output_format},
-                     mkldnn_engine, output_data));
-      reorder_p = std::make_shared<reorder>(*dst_mem, *target_mem);
-    }
+    auto sum_p = handler.AcquireForwardPrimitive();
 
-    mkldnn::stream astream(mkldnn_engine);
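+    // The sum primitive consumes its inputs through MKLDNN_ARG_MULTIPLE_SRC + i
+    // slots, one per registered src memory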
     std::unordered_map<int, memory> args;
     for (size_t i = 0; i < srcs_mem.size(); ++i) {
-      args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, srcs_mem.at(i)});
+      args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
     }
     args.insert({MKLDNN_ARG_DST, *dst_mem});
 
-    sum_prim.execute(astream, args);
+    mkldnn::stream astream(dev_ctx.GetEngine());
+    sum_p->execute(astream, args);
     astream.wait();
 
+    // For in-place execution, which sum does not natively support, we fake it:
+    // data is reordered from the oneDNN dst memory back into the input buffer
     if (in_place) {
+      const std::string reorder_key = platform::CreateKey(
+          framework::vectorize(output->dims()), ctx.OutputName("Out") + "-I");
+
+      auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
+      auto output_tz = framework::vectorize<int64_t>(output->dims());
+      platform::ReorderMKLDNNHandler reorder_handler(
+          output_tz, output->type(), framework::ToMKLDNNDataType(in_out.type()),
+          dev_ctx, dev_ctx.GetEngine(), reorder_key);
+
+      auto target_mem = reorder_handler.AcquireDstMemory(
+          output, in_out.format(), ctx.GetPlace());
+
+      auto reorder_p = reorder_handler.AcquireReorder(target_mem, dst_mem);
       reorder_p->execute(astream, *dst_mem, *target_mem);
       astream.wait();
     }
 
-    output->set_layout(DataLayout::kMKLDNN);
-    output->set_format(output_format);
+    output->set_layout(framework::DataLayout::kMKLDNN);
+    output->set_format(platform::GetMKLDNNFormat(*dst_mem));
   }
 };