|
|
|
@ -371,6 +371,13 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
|
|
|
|
|
: platform::MKLDNNHandlerT<T, dnnl::binary>(
|
|
|
|
|
dev_ctx, engine, cpu_place,
|
|
|
|
|
platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) {
|
|
|
|
|
// broadcasting combined with in-place may require a longer key
|
|
|
|
|
auto rankdiff = x->dims().size() - y->dims().size();
|
|
|
|
|
if (rankdiff > 0) {
|
|
|
|
|
this->key_ += std::to_string(rankdiff);
|
|
|
|
|
this->key_common_ += std::to_string(rankdiff);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!this->isCached()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x->layout(), DataLayout::kMKLDNN,
|
|
|
|
@ -390,17 +397,19 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
|
|
|
|
|
const auto src_y_tz = framework::vectorize(y->dims());
|
|
|
|
|
const auto dst_tz = framework::vectorize(z->dims());
|
|
|
|
|
|
|
|
|
|
// TODO(jczaja): Add function checking if data already exists
|
|
|
|
|
const auto src0_md = dnnl::memory::desc(
|
|
|
|
|
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
|
|
|
|
|
const auto src1_md = dnnl::memory::desc(
|
|
|
|
|
auto src1_md = dnnl::memory::desc(
|
|
|
|
|
src_y_tz, platform::MKLDNNGetDataType<T>(), y->format());
|
|
|
|
|
if (rankdiff > 0) {
|
|
|
|
|
std::vector<int64_t> ones(rankdiff, 1);
|
|
|
|
|
std::vector<int64_t> dims1_ex(src_y_tz);
|
|
|
|
|
dims1_ex.insert(dims1_ex.begin(), ones.begin(), ones.end());
|
|
|
|
|
src1_md = src1_md.reshape(dims1_ex);
|
|
|
|
|
}
|
|
|
|
|
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
MKLDNNMemoryFormat::any);
|
|
|
|
|
|
|
|
|
|
// Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
|
|
|
|
|
// TODO(jczaja): Binary primitive supports broadcasting, so we can support
|
|
|
|
|
// this in kernel
|
|
|
|
|
this->AcquireForwardPrimitiveDescriptor(dnnl::algorithm::binary_add,
|
|
|
|
|
src0_md, src1_md, dst_md);
|
|
|
|
|
}
|
|
|
|
@ -410,7 +419,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
|
|
|
|
|
const framework::Tensor* input) {
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(
|
|
|
|
|
this->fwd_pd_->src_desc(), to_void_cast<T>(input_data), "@src1_mem_p");
|
|
|
|
|
this->fwd_pd_->src1_desc(), to_void_cast<T>(input_data), "@src1_mem_p");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|