@@ -27,6 +27,8 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
+using framework::DataLayout;
+using framework::Tensor;
 using user_function = std::function<std::shared_ptr<float>(const float*)>;
 using memory = mkldnn::memory;
@@ -108,6 +110,13 @@ class MKLDNNHandlerT {
   }
 
  protected:
+  bool isCached() {
+    const std::string key_pd = key_common_ + "@forward_pd";
+    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
+        dev_ctx_.GetBlob(key_pd));
+    return (fwd_pd_ != nullptr);
+  }
+
   template <typename... Args>
   void AcquireForwardPrimitiveDescriptor(Args&&... args) {
     // Forward PD has to be passed to Grad op that
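The new isCached() helper probes the device context's blob cache under the handler's key_common_ + "@forward_pd" key and, on a hit, restores fwd_pd_ directly, so a derived handler can skip rebuilding its descriptors on every iteration. A minimal sketch of the intended pattern, under that assumption; ExampleHandler and its arguments are illustrative and not part of this diff:

    // Hypothetical handler built on MKLDNNHandlerT (illustrative only).
    template <typename T>
    class ExampleHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
     public:
      ExampleHandler(const platform::MKLDNNDeviceContext& dev_ctx,
                     const mkldnn::engine engine, platform::Place cpu_place,
                     const std::vector<int64_t>& dims, const std::string& name)
          : platform::MKLDNNHandlerT<T, dnnl::binary>(
                dev_ctx, engine, cpu_place, platform::CreateKey(dims, name)) {
        // Descriptor construction runs only on a cache miss; on a hit,
        // isCached() has already re-populated fwd_pd_ from the blob map.
        if (!this->isCached()) {
          auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
                                       MKLDNNMemoryFormat::nchw);
          this->AcquireForwardPrimitiveDescriptor(dnnl::algorithm::binary_add,
                                                  md, md, md);
        }
      }
    };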
@@ -355,22 +364,46 @@ class MKLDNNHandler {
 template <typename T>
 class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
  public:
-  BinaryMKLDNNHandler(const dnnl::algorithm algo,
-                      const std::vector<int64_t>& dims,
-                      const MKLDNNMemoryFormat src0_fmt,
-                      const MKLDNNMemoryFormat src1_fmt,
-                      const platform::MKLDNNDeviceContext& dev_ctx,
-                      platform::Place cpu_place, const std::string& uniq_name)
+  BinaryMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
+                      const mkldnn::engine engine, platform::Place cpu_place,
+                      const Tensor* x, const Tensor* y, Tensor* z,
+                      const std::string uniq_name)
       : platform::MKLDNNHandlerT<T, dnnl::binary>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
-    // TODO(jczaja): Add function checking if data already exists
-    auto src0_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src0_fmt);
-    auto src1_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src1_fmt);
-    auto dst_md =
-        memory::desc(dims, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
-
-    this->AcquireForwardPrimitiveDescriptor(algo, src0_md, src1_md, dst_md);
+            dev_ctx, engine, cpu_place,
+            platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) {
+    if (!this->isCached()) {
+      PADDLE_ENFORCE_EQ(
+          x->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument("Wrong layout set for X tensor"));
+      PADDLE_ENFORCE_NE(
+          x->format(), MKLDNNMemoryFormat::undef,
+          platform::errors::InvalidArgument("Wrong format set for X tensor"));
+
+      PADDLE_ENFORCE_EQ(
+          y->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument("Wrong layout set for Y tensor"));
+      PADDLE_ENFORCE_NE(
+          y->format(), MKLDNNMemoryFormat::undef,
+          platform::errors::InvalidArgument("Wrong format set for Y tensor"));
+
+      const auto src_x_tz = framework::vectorize(x->dims());
+      const auto src_y_tz = framework::vectorize(y->dims());
+      const auto dst_tz = framework::vectorize(z->dims());
+
+      // TODO(jczaja): Add function checking if data already exists
+      const auto src0_md = dnnl::memory::desc(
+          src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
+      const auto src1_md = dnnl::memory::desc(
+          src_y_tz, platform::MKLDNNGetDataType<T>(), y->format());
+      const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
+                                       MKLDNNMemoryFormat::any);
+
+      // Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
+      // TODO(jczaja): Binary primitive supports broadcasting, so we can
+      // support this in the kernel
+      this->AcquireForwardPrimitiveDescriptor(dnnl::algorithm::binary_add,
+                                              src0_md, src1_md, dst_md);
+    }
+  }
 
   std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
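After this change a kernel hands the handler raw tensors and lets it derive dims, formats, and the cache key itself. A rough sketch of a call site (e.g. an MKL-DNN elementwise add kernel) under that pattern; the surrounding variables dev_ctx, mkldnn_engine, ctx, x, y, z and the use of ctx.OutputName("Out") are assumptions for illustration, not code from this diff:

    // Hypothetical call site inside an MKL-DNN elementwise kernel.
    BinaryMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
                                   x, y, z, ctx.OutputName("Out"));
    const auto src_x_memory = handler.AcquireSrcMemory(x);
    const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
    const auto dst_memory = handler.AcquireDstMemory(z);
    const auto binary_prim = handler.AcquireForwardPrimitive();

    // Binary primitive takes two sources; the handler resolved their
    // descriptors (and the cache key) from the tensors themselves.
    mkldnn::stream astream(mkldnn_engine);
    binary_prim->execute(astream, {{DNNL_ARG_SRC_0, *src_x_memory},
                                   {DNNL_ARG_SRC_1, *src_y_memory},
                                   {DNNL_ARG_DST, *dst_memory}});
    astream.wait();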