Develop a fake dequantized op for fixed-point quantization training framework. (#10965)
* Develop a fake dequantized op for fixed-point quantization training framework. * Add the missing file. release/0.13.0
parent
66ec827a92
commit
3a29821bd5
@ -0,0 +1,76 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fake_dequantize_op.h"
|
||||
#include <string>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
FakeDequantizeMaxAbsOp(const std::string &type,
|
||||
const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
|
||||
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, outputs, attributes and user-facing documentation of
// the fake_dequantize_max_abs operator.
class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input with float-32/64 type is the "
             "low precision tensor.");
    AddOutput("Out",
              "(Tensor) The output is the dequantized high "
              "precision tensor.");
    AddAttr<int>("num_bits",
                 "(int) `num_bits` is the quantization level bits, "
                 "such as 2, 5, 8.");
    // Adjacent string literals are concatenated verbatim, so the first one
    // must end with a space to keep the two sentences apart.
    AddAttr<float>("scale",
                   "(float) The maximum absolute value of low precision tensor. "
                   "It is usually calculated by the fake_quantize_max_abs_op.");
    AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.

This calculation is an opposite operation of FakeQuantizeMaxAbsOp:

$$Out = \frac{scale*X}{2^{num_bits} - 1}$$

)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
using CPU = paddle::platform::CPUDeviceContext;
|
||||
|
||||
REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
|
||||
ops::FakeDequantizeMaxAbsOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
|
||||
ops::FakeDequantizeMaxAbsKernel<CPU, float>,
|
||||
ops::FakeDequantizeMaxAbsKernel<CPU, double>);
|
@ -0,0 +1,21 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fake_dequantize_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
using CUDA = paddle::platform::CUDADeviceContext;
|
||||
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
|
||||
ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
|
||||
ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
|
@ -0,0 +1,42 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
template <typename DeviceContext, typename T>
|
||||
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
virtual void Compute(const framework::ExecutionContext& ctx) const {
|
||||
auto* in = ctx.Input<framework::Tensor>("X");
|
||||
auto* out = ctx.Output<framework::Tensor>("Out");
|
||||
out->mutable_data<T>(in->place());
|
||||
|
||||
int num_bits = ctx.Attr<int>("num_bits");
|
||||
T scale = static_cast<T>(ctx.Attr<float>("scale"));
|
||||
int range = std::pow(2, num_bits) - 1;
|
||||
|
||||
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
|
||||
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
|
||||
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
|
||||
eigen_out.device(dev) = (scale / range) * eigen_in;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import math
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def quantize_max_abs(x, num_bits):
    """Reference max-abs quantization used to build the test input.

    Divides `x` by its largest absolute value, multiplies by
    2**num_bits - 1 and rounds with np.round.

    Returns:
        A tuple (quantized, max_abs): the rounded array and the scale
        (largest absolute value of `x`) that was used.
    """
    max_abs = np.max(np.abs(x).flatten())
    levels = math.pow(2, num_bits) - 1
    quantized = np.round(x / max_abs * levels)
    return quantized, max_abs
|
||||
|
||||
|
||||
def dequantize_max_abs(x, num_bits, scale):
    """Reference dequantization: returns (scale / (2**num_bits - 1)) * x."""
    levels = math.pow(2, num_bits) - 1
    step = scale / levels
    return step * x
|
||||
|
||||
|
||||
class TestFakeDequantizeMaxAbsOp(OpTest):
    """Output test for the fake_dequantize_max_abs op.

    The op input is a random float32 matrix quantized by quantize_max_abs;
    the expected output is that input passed through dequantize_max_abs.
    Subclasses override set_args() to change the bit width.
    """

    def set_args(self):
        # Number of quantization bits exercised by this test case.
        self.num_bits = 8

    def setUp(self):
        self.set_args()
        self.op_type = "fake_dequantize_max_abs"
        x = np.random.randn(31, 65).astype("float32")
        yq, scale = quantize_max_abs(x, self.num_bits)
        ydq = dequantize_max_abs(yq, self.num_bits, scale)

        self.inputs = {'X': yq}
        # `scale` is a NumPy scalar; convert it to a plain Python float for
        # the op's float attribute.
        self.attrs = {'num_bits': self.num_bits, 'scale': float(scale)}
        self.outputs = {'Out': ydq}

    def test_check_output(self):
        self.check_output()
|
||||
|
||||
|
||||
class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp):
    """Runs the TestFakeDequantizeMaxAbsOp checks with 5-bit quantization.

    Deriving from TestFakeDequantizeMaxAbsOp (not OpTest directly) is what
    supplies setUp() and test_check_output(); without it this class would
    define no test at all.
    """

    def set_args(self):
        self.num_bits = 5
|
||||
|
||||
|
||||
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue