Merge pull request #14973 from xiaolil1/dequantize
Add Dequantize OPrevert-15207-remove_op_handle_lock_and_fix_var
commit
8a6ac4dba7
@ -0,0 +1,88 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "mkldnn.hpp"
|
||||||
|
#include "paddle/fluid/framework/data_layout_transform.h"
|
||||||
|
#include "paddle/fluid/framework/tensor.h"
|
||||||
|
#include "paddle/fluid/operators/dequantize_op.h"
|
||||||
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using mkldnn::memory;
|
||||||
|
using mkldnn::primitive;
|
||||||
|
using mkldnn::reorder;
|
||||||
|
using platform::to_void_cast;
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
using framework::DataLayout;
|
||||||
|
using mkldnn::stream;
|
||||||
|
using platform::GetMKLDNNFormat;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class DeQuantOpKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* input = ctx.Input<Tensor>("Input");
|
||||||
|
auto scale_data = ctx.Attr<float>("Scale");
|
||||||
|
auto* output = ctx.Output<Tensor>("Output");
|
||||||
|
auto& dev_ctx =
|
||||||
|
ctx.template device_context<platform::MKLDNNDeviceContext>();
|
||||||
|
const auto& engine = dev_ctx.GetEngine();
|
||||||
|
|
||||||
|
const T* input_data = input->data<T>();
|
||||||
|
float* output_data = output->mutable_data<float>(ctx.GetPlace());
|
||||||
|
std::vector<float> reorder_scale = {1.0f / scale_data};
|
||||||
|
|
||||||
|
std::vector<primitive> pipeline;
|
||||||
|
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
|
||||||
|
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
|
||||||
|
mkldnn::memory::data_type src_dt =
|
||||||
|
paddle::framework::ToMKLDNNDataType(input->type());
|
||||||
|
mkldnn::memory::format src_fmt = input->format();
|
||||||
|
|
||||||
|
mkldnn::primitive_attr attri;
|
||||||
|
int mask = 0;
|
||||||
|
attri.set_output_scales(mask, reorder_scale);
|
||||||
|
|
||||||
|
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
|
||||||
|
auto src_pd = mkldnn::memory::primitive_desc(src_md, engine);
|
||||||
|
auto src_memory =
|
||||||
|
std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
|
||||||
|
std::shared_ptr<primitive::at> src_memory_p =
|
||||||
|
std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
|
||||||
|
|
||||||
|
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
|
||||||
|
memory::format::nchw);
|
||||||
|
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
|
||||||
|
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<float>(output_data));
|
||||||
|
|
||||||
|
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
|
||||||
|
new reorder::primitive_desc(src_pd, dst_pd, attri));
|
||||||
|
auto reorder_p = std::shared_ptr<reorder>(
|
||||||
|
new reorder(*reorder_pd, *src_memory_p, dst_memory));
|
||||||
|
pipeline.push_back(*reorder_p);
|
||||||
|
stream(stream::kind::eager).submit(pipeline).wait();
|
||||||
|
|
||||||
|
output->set_format(GetMKLDNNFormat(dst_memory));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
|
||||||
|
REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace,
|
||||||
|
ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>);
|
@ -0,0 +1,45 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/dequantize_op.h"
|
||||||
|
#ifdef PADDLE_WITH_MKLDNN
|
||||||
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
framework::OpKernelType DeQuantOp::GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const {
|
||||||
|
framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
|
||||||
|
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
|
||||||
|
|
||||||
|
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
|
||||||
|
ctx.GetPlace(), layout_, library_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeQuantOpMaker::Make() {
|
||||||
|
AddInput("Input", "input data");
|
||||||
|
AddOutput("Output", "output data");
|
||||||
|
AddAttr<float>("Scale", "scale data").SetDefault({1.0f});
|
||||||
|
AddComment(R"DOC(This op will dequantize data from INT8 to FP32)DOC");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
|
||||||
|
REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker,
|
||||||
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
@ -0,0 +1,54 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using framework::OpKernelType;
|
||||||
|
using framework::Tensor;
|
||||||
|
|
||||||
|
class DeQuantOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
ctx->SetOutputDim("Output", ctx->GetInputDim("Input"));
|
||||||
|
ctx->ShareLoD("Input", /*->*/ "Output");
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DeQuantOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
void Make() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DeQuantGradOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,73 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeQuantizeOp(OpTest):
|
||||||
|
def setUp(self):
|
||||||
|
self.op_type = 'dequantize'
|
||||||
|
self.scale = 2.0
|
||||||
|
self.input_size = [1, 1, 5, 5] #Naive nChw16c
|
||||||
|
self.data_type = 'int8'
|
||||||
|
self.set_scale()
|
||||||
|
self.set_data_type()
|
||||||
|
|
||||||
|
if self.data_type == 'int8':
|
||||||
|
input = (np.random.randint(0, 100, self.input_size) - 50
|
||||||
|
).astype(self.data_type)
|
||||||
|
output = (input * (1 / self.scale)).astype('float')
|
||||||
|
else:
|
||||||
|
input = (np.random.randint(0, 100,
|
||||||
|
self.input_size)).astype(self.data_type)
|
||||||
|
output = (input * (1 / self.scale)).astype('float')
|
||||||
|
|
||||||
|
self.inputs = {'Input': OpTest.np_dtype_to_fluid_dtype(input)}
|
||||||
|
|
||||||
|
self.outputs = {'Output': output}
|
||||||
|
|
||||||
|
self.attrs = {'Scale': self.scale, }
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
def set_scale(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_data_type(OpTest):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeQuantizeOp1(TestDeQuantizeOp):
|
||||||
|
def set_scale(self):
|
||||||
|
self.scale = 1.5
|
||||||
|
|
||||||
|
def set_data_type(self):
|
||||||
|
self.data_type = 'int8'
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeQuantizeOp2(TestDeQuantizeOp):
|
||||||
|
def set_scale(self):
|
||||||
|
self.scale = 0.8
|
||||||
|
|
||||||
|
def set_data_type(self):
|
||||||
|
self.data_type = 'uint8'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue