|
|
@ -21,14 +21,10 @@ limitations under the License. */
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
using mkldnn::memory;
|
|
|
|
using dnnl::memory;
|
|
|
|
using mkldnn::primitive;
|
|
|
|
using dnnl::reorder;
|
|
|
|
using mkldnn::reorder;
|
|
|
|
|
|
|
|
using platform::to_void_cast;
|
|
|
|
using platform::to_void_cast;
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
using framework::DataLayout;
|
|
|
|
|
|
|
|
using mkldnn::stream;
|
|
|
|
|
|
|
|
using platform::GetMKLDNNFormat;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
class ReQuantOpKernel : public framework::OpKernel<T> {
|
|
|
|
class ReQuantOpKernel : public framework::OpKernel<T> {
|
|
|
@ -42,42 +38,66 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
|
|
|
|
ctx.template device_context<platform::MKLDNNDeviceContext>();
|
|
|
|
ctx.template device_context<platform::MKLDNNDeviceContext>();
|
|
|
|
const auto& engine = dev_ctx.GetEngine();
|
|
|
|
const auto& engine = dev_ctx.GetEngine();
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<primitive> pipeline;
|
|
|
|
auto src_tz = paddle::framework::vectorize(input->dims());
|
|
|
|
auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
|
|
|
|
|
|
|
|
auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
|
|
|
|
std::string key = platform::CreateKey(src_tz, scale_in, scale_out,
|
|
|
|
mkldnn::memory::data_type src_dt =
|
|
|
|
ctx.OutputName("Output"));
|
|
|
|
paddle::framework::ToMKLDNNDataType(input->type());
|
|
|
|
const std::string key_prim = key + "@reorder_p";
|
|
|
|
mkldnn::memory::data_type dst_dt = src_dt;
|
|
|
|
const std::string key_src_mem = key + "@src_mem";
|
|
|
|
MKLDNNMemoryFormat src_fmt = MKLDNNMemoryFormat::nhwc;
|
|
|
|
const std::string key_dst_mem = key + "@dst_mem";
|
|
|
|
MKLDNNMemoryFormat dst_fmt = MKLDNNMemoryFormat::nhwc;
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<dnnl::memory> src_memory;
|
|
|
|
|
|
|
|
std::shared_ptr<dnnl::memory> dst_memory;
|
|
|
|
|
|
|
|
std::shared_ptr<reorder> reorder_p;
|
|
|
|
|
|
|
|
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
|
|
|
|
|
|
|
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
float scale_shift = scale_out / scale_in;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mkldnn::primitive_attr attri;
|
|
|
|
if (reorder_p == nullptr) {
|
|
|
|
|
|
|
|
dnnl::primitive_attr attri;
|
|
|
|
int mask = 0;
|
|
|
|
int mask = 0;
|
|
|
|
|
|
|
|
float scale_shift = scale_out / scale_in;
|
|
|
|
attri.set_output_scales(mask, {scale_shift});
|
|
|
|
attri.set_output_scales(mask, {scale_shift});
|
|
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
|
|
|
|
auto dst_tz = paddle::framework::vectorize(output->dims());
|
|
|
|
auto src_memory = std::make_shared<mkldnn::memory>(
|
|
|
|
dnnl::memory::data_type src_dt =
|
|
|
|
src_md, engine, to_void_cast<T>(input_data));
|
|
|
|
paddle::framework::ToMKLDNNDataType(input->type());
|
|
|
|
|
|
|
|
dnnl::memory::data_type dst_dt = src_dt;
|
|
|
|
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, dst_fmt);
|
|
|
|
|
|
|
|
auto dst_memory =
|
|
|
|
auto src_md =
|
|
|
|
mkldnn::memory(dst_md, engine, to_void_cast<T>(output_data));
|
|
|
|
platform::MKLDNNMemDesc({src_tz}, src_dt, MKLDNNMemoryFormat::nhwc);
|
|
|
|
|
|
|
|
src_memory = std::make_shared<dnnl::memory>(src_md, engine,
|
|
|
|
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
|
|
|
|
to_void_cast<T>(input_data));
|
|
|
|
new reorder::primitive_desc(*src_memory, dst_memory, attri));
|
|
|
|
|
|
|
|
|
|
|
|
auto dst_md =
|
|
|
|
auto reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
|
|
|
|
platform::MKLDNNMemDesc({dst_tz}, dst_dt, MKLDNNMemoryFormat::nhwc);
|
|
|
|
|
|
|
|
dst_memory = std::make_shared<dnnl::memory>(dst_md, engine,
|
|
|
|
|
|
|
|
to_void_cast<T>(output_data));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto reorder_pd =
|
|
|
|
|
|
|
|
reorder::primitive_desc(*src_memory, *dst_memory, attri);
|
|
|
|
|
|
|
|
reorder_p = std::make_shared<reorder>(reorder_pd);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_prim, reorder_p);
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_src_mem, src_memory);
|
|
|
|
|
|
|
|
dev_ctx.SetBlob(key_dst_mem, dst_memory);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
src_memory =
|
|
|
|
|
|
|
|
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
|
|
|
|
|
|
|
|
src_memory->set_data_handle(to_void_cast<T>(input_data));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dst_memory =
|
|
|
|
|
|
|
|
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
|
|
|
|
|
|
|
|
dst_memory->set_data_handle(output_data);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
mkldnn::stream astream(engine);
|
|
|
|
dnnl::stream astream(engine);
|
|
|
|
reorder_p->execute(astream, *src_memory, dst_memory);
|
|
|
|
reorder_p->execute(astream, *src_memory, *dst_memory);
|
|
|
|
astream.wait();
|
|
|
|
astream.wait();
|
|
|
|
|
|
|
|
|
|
|
|
output->set_layout(DataLayout::kMKLDNN);
|
|
|
|
output->set_layout(framework::DataLayout::kMKLDNN);
|
|
|
|
output->set_format(GetMKLDNNFormat(dst_memory));
|
|
|
|
output->set_format(platform::GetMKLDNNFormat(*dst_memory));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|