Fix (de/re)quantize cache keys (#26549)

test_feature_precision_test_c
Wojciech Uss 5 years ago committed by GitHub
parent eeda90d674
commit 5c2b9258a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -51,11 +51,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
MKLDNNMemoryFormat src_fmt = input->format();
std::string key =
platform::CreateKey(src_dt, src_tz, ctx.OutputName("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt,
src_tz, ctx.OutputName("Output"));
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory;

@ -48,11 +48,12 @@ class QuantOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
bool is_negative = ctx.Attr<bool>("is_negative_input");
std::string key = platform::CreateKey(src_tz, scale_data, is_negative,
ctx.OutputName("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
std::string key =
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_data,
is_negative, ctx.OutputName("Output"));
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory;

@ -40,11 +40,12 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
auto src_tz = paddle::framework::vectorize(input->dims());
std::string key = platform::CreateKey(src_tz, scale_in, scale_out,
ctx.OutputName("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
std::string key =
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in,
scale_out, ctx.OutputName("Output"));
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<dnnl::memory> src_memory;
std::shared_ptr<dnnl::memory> dst_memory;

Loading…
Cancel
Save