|
|
|
@ -18,15 +18,32 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& dev_ctx,
|
|
|
|
|
const framework::Tensor* in, const framework::Tensor* scale,
|
|
|
|
|
T max_range, framework::Tensor* out) {
|
|
|
|
|
auto in_e = framework::EigenVector<T>::Flatten(*in);
|
|
|
|
|
const T* scale_factor = scale->data<T>();
|
|
|
|
|
auto out_e = framework::EigenVector<T>::Flatten(*out);
|
|
|
|
|
|
|
|
|
|
auto& dev = *dev_ctx.eigen_device();
|
|
|
|
|
out_e.device(dev) = (scale_factor[0] / max_range) * in_e;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Explicit instantiations of the CPU DequantizeFunctor for the float and
// double element types.
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
|
|
|
|
|
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
|
|
|
|
|
|
|
|
|
|
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
FakeDequantizeMaxAbsOp(const std::string &type,
|
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
FakeDequantizeMaxAbsOp(const std::string& type,
|
|
|
|
|
const framework::VariableNameMap& inputs,
|
|
|
|
|
const framework::VariableNameMap& outputs,
|
|
|
|
|
const framework::AttributeMap& attrs)
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
@ -42,21 +59,17 @@ class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(Tensor) The input with float-32/64 type is the "
|
|
|
|
|
"low precision tensor.");
|
|
|
|
|
AddInput("Scale", "(float) The scale in quantization stage.");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor) The output is the dequantized high "
|
|
|
|
|
"precision tensor.");
|
|
|
|
|
AddAttr<int>("num_bits",
|
|
|
|
|
"(int) `num_bits` is the quantization level bits, "
|
|
|
|
|
"such as 2, 5, 8.");
|
|
|
|
|
AddAttr<float>("scale",
|
|
|
|
|
"(float) The maximum absolute value of low precision tensor."
|
|
|
|
|
"It is usually calculated by the fake_quantize_max_abs_op.");
|
|
|
|
|
AddAttr<float>("max_range", "(float) The max range in quantization stage.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
FakeDequantizeMaxAbsOp operator.
|
|
|
|
|
|
|
|
|
|
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
|
|
|
|
|
|
|
|
|
|
$$Out = \frac{scale*X}{2^{num_bits} - 1}$$
|
|
|
|
|
$$Out = \frac{scale*X}{ max_range }$$
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|