Quant op dev (#25932)

* Finished ChannelWiseQuantDequantAbsMaxOp and Passed unittests.

* Finished channel-wise quantize strategy in imperative quantization.

* Added Cuda code of ChannelWiseQuantDequantMaxAbsOP
Add Cuda code of ChannelWiseQuantDequantMaxAbsOp

* Add quant_axis for channel_wise quant.

* fixed a bug in unnitests, which will not trigger axis = 1 case and cannot meet the coverage rate requirement.

* Added some assert infomation and fixed some coding style mistakes.
revert-27520-disable_pr
huangxu96 4 years ago committed by GitHub
parent aa7835efee
commit 02606d45ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -174,7 +174,64 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
float>;
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto* scale_data = scale.data<T>();
auto* in_data = in.data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto in_dims = in.dims();
const int64_t channel = in_dims[quant_axis];
platform::Transform<platform::CPUDeviceContext> trans;
if (quant_axis == 0) {
const int64_t channel_size = in.numel() / channel;
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
ClipFunctor<T>(-s, s));
}
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
T inv_s = inverse(s);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) =
(bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
}
} else if (quant_axis == 1) {
const int64_t step_i = in.numel() / in_dims[0];
const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]);
for (int i = 0; i < in_dims[0]; i++) {
for (int j = 0; j < in_dims[1]; j++) {
T s = scale_data[j];
T inv_s = inverse(s);
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
auto* cur_out_data = out_data + i * step_i + j * step_j;
trans(ctx, start, end, cur_out_data, ClipFunctor<T>(-s, s));
for (int k = 0; k < step_j; k++) {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) *
s / static_cast<T>(bin_cnt);
}
}
}
}
}
};
template struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext,
float>;
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
@ -360,6 +417,75 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$
}
};
class FakeChannelWiseQuantizeDequantizeAbsMaxOp
: public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FakeChannelWiseQuantizeDequantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FakeChannelWiseQuantizeDequantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"FakeChannelWiseQuantizeDequantizeAbsMax");
int quant_axis = ctx->Attrs().Get<int>("quant_axis");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input is float data type.");
AddOutput("Out",
"(Tensor) Output of quantized and dequantized low level tensor, "
"saved as float data type.");
AddOutput("OutScale", "(Tensor) Current channel wise scale");
AddAttr<int>("quant_axis",
"(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0)
.AddCustomChecker([](const int& quant_axis) {
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
});
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
});
AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.
$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c}) * \frac{scale_c} {range}$$
In above three formulas, the range value of c is as follow:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
}
};
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeRangeAbsMaxOp(const std::string& type,
@ -666,3 +792,12 @@ REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CPU, float>);
REGISTER_OPERATOR(fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>);

@ -417,8 +417,90 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
}
};
template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
float>;
// ChannelClipAndQuantDequantKernel for quant_axis is 0
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(
const T* in, const T* scale, const int bin_cnt, const int n, const int c,
T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x];
T inv_s = inverse(s);
for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt;
}
}
// ChannelClipAndQuantDequantKernel for quant_axis is 1
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(
const T* in, const T* scale, const int bin_cnt, const int n, const int cin,
const int cout, T* out) {
T s = scale[blockIdx.x % cout];
T inv_s = inverse(s);
int wh_size = n / (cin * cout);
const T* in_c = in + blockIdx.x * wh_size;
T* out_c = out + blockIdx.x * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt;
}
}
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
int num = in.numel();
auto in_dims = in.dims();
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
if (quant_axis == 0) {
int grid = in_dims[0];
int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis0<
T><<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt,
num, in_dims[0], out_data);
} else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1];
int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis1<
T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data);
}
}
};
template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext,
float>;
} // namespace operators
} // namespace paddle
@ -443,3 +525,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);

@ -72,6 +72,13 @@ struct ChannelClipAndFakeQuantFunctor {
const int quant_axis, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct ChannelClipFakeQuantDequantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
const int quant_axis, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum,
@ -154,6 +161,30 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace());
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis");
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
out_scale_data);
ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
}
};
template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
public:

@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include <cstring> // for memcpy
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h"

@ -111,6 +111,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"fake_quantize_dequantize_moving_average_abs_max",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
{"fake_channel_wise_quantize_dequantize_abs_max", {"Out", "OutScale"}},
{"check_finite_and_unscale", {"Out", "FoundInfinite"}},
{"update_loss_scaling",
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},

@ -99,7 +99,12 @@ class ImperativeQuantAware(object):
self._activation_bits = activation_bits
self._moving_rate = moving_rate
quant_type = {'abs_max', 'moving_average_abs_max'}
quant_type = {
'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max'
}
assert activation_quantize_type != 'channel_wise_abs_max', \
"The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be "
@ -108,8 +113,8 @@ class ImperativeQuantAware(object):
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'moving_average_abs_max' now." %
(str(weight_quantize_type)))
"'abs_max' or 'moving_average_abs_max' or 'channel_wise_abs_max' now."
% (str(weight_quantize_type)))
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type

@ -24,7 +24,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype
__all__ = [
'FakeQuantMovingAverage', 'FakeQuantAbsMax', 'QuantizedConv2D',
'QuantizedLinear'
'QuantizedLinear', 'FakeChannelWiseQuantDequantAbsMax'
]
@ -209,6 +209,89 @@ class FakeQuantAbsMax(layers.Layer):
return quant_out
class FakeChannelWiseQuantDequantAbsMax(layers.Layer):
def __init__(self,
name=None,
channel_num=None,
quant_bits=8,
quant_axis=0,
dtype='float32',
quant_on_weight=False):
assert quant_on_weight == True, "Channel_wise only can be used on weight quantization."
super(FakeChannelWiseQuantDequantAbsMax, self).__init__()
self._quant_bits = quant_bits
self._quant_axis = quant_axis
self._dtype = dtype
self._name = name
self._channel_num = channel_num
scale_prefix = "{}.scale".format(
name) if name else 'quant_dequant.scale'
self._scale_name = unique_name.generate(scale_prefix)
if quant_on_weight:
scale_attr = ParamAttr(
name=self._scale_name,
initializer=Constant(0.0),
trainable=False)
self._scale = self.create_parameter(
shape=[self._channel_num], attr=scale_attr, dtype=self._dtype)
self._scale.stop_gradient = True
else:
self._scale = None
def forward(self, input):
if in_dygraph_mode():
attrs = ('bit_length', self._quant_bits, 'quant_axis',
self._quant_axis)
quant_out = _varbase_creator(
type=input.type,
name="{}.quantized.dequantized".format(input.name),
shape=input.shape,
dtype=input.dtype,
persistable=False)
out_scale = self._scale
if out_scale is None:
out_scale = _varbase_creator(
type=core.VarDesc.VarType.LOD_TENSOR,
name=self._scale_name,
shape=[self._channel_num],
dtype=self._dtype,
persistable=False)
out_scale.stop_gradient = True
out, _, = core.ops.fake_channel_wise_quantize_dequantize_abs_max(
input, quant_out, out_scale, *attrs)
return out
check_variable_and_dtype(input, 'input', ['float32'],
"FakeChannelWiseQuantDequantAbsMax")
attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis}
inputs = {"X": [input]}
quant_out = self._helper.create_variable(
name="{}.quantized.dequantized".format(input.name),
dtype=input.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
out_scale = self._scale
if not out_scale:
out_scale = self._helper.create_variable(
name=self._scale_name,
dtype=self._dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
outputs = {"Out": [quant_out], "OutScale": [out_scale]}
self._helper.append_op(
type="fake_channel_wise_quantize_dequantize_abs_max",
inputs=inputs,
outputs=outputs,
attrs=attrs)
return quant_out
def _get_fake_quant_type(quant_type, **kwargs):
call_args = {
"name": kwargs.get("name", None),
@ -220,10 +303,17 @@ def _get_fake_quant_type(quant_type, **kwargs):
call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
elif quant_type == 'moving_average_abs_max':
call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
elif quant_type == 'channel_wise_abs_max':
call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
call_args["channel_num"] = kwargs.get("channel_num", None)
call_args["quant_axis"] = kwargs.get("quant_axis", 0)
assert call_args["channel_num"] is not None, (
"You need to input channel_num"
"when you use channel_wise_abs_max strategy.")
fake_quant_map = {
'abs_max': FakeQuantAbsMax,
'moving_average_abs_max': FakeQuantMovingAverage
'moving_average_abs_max': FakeQuantMovingAverage,
'channel_wise_abs_max': FakeChannelWiseQuantDequantAbsMax
}
return fake_quant_map[quant_type](**call_args)
@ -255,19 +345,23 @@ class QuantizedConv2D(layers.Layer):
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
# For FakeQuant
self._conv2d_quant_axis = 0
self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type,
name=self.weight.name,
moving_rate=moving_rate,
quant_bits=weight_bits,
dtype=self._dtype,
quant_on_weight=True)
quant_on_weight=True,
channel_num=self.weight.shape[self._conv2d_quant_axis],
quant_axis=self._conv2d_quant_axis)
self._fake_quant_input = _get_fake_quant_type(
activation_quantize_type,
name=layer.full_name(),
moving_rate=moving_rate,
quant_bits=activation_bits,
dtype=self._dtype)
dtype=self._dtype,
quant_on_weight=False)
def forward(self, input):
quant_input = self._fake_quant_input(input)
@ -341,19 +435,23 @@ class QuantizedLinear(layers.Layer):
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
# For FakeQuant
self._linear_quant_axis = 1
self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type,
name=self.weight.name,
moving_rate=moving_rate,
quant_bits=weight_bits,
dtype=self._dtype,
quant_on_weight=True)
quant_on_weight=True,
channel_num=self.weight.shape[self._linear_quant_axis],
quant_axis=self._linear_quant_axis)
self._fake_quant_input = _get_fake_quant_type(
activation_quantize_type,
name=layer.full_name(),
moving_rate=moving_rate,
quant_bits=activation_bits,
dtype=self._dtype)
dtype=self._dtype,
quant_on_weight=False)
def forward(self, input):
quant_input = self._fake_quant_input(input)

@ -181,7 +181,6 @@ class TestImperativeQat(unittest.TestCase):
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc = fluid.layers.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)

@ -306,5 +306,70 @@ class TestFakeQuantDequantAbsOp(OpTest):
self.check_grad(["X"], "Out", user_defined_grads=gradient)
class TestChannelWiseFakeQuantDequantOp(OpTest):
def setUp(self):
self.set_arg()
assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."
self.op_type = "fake_channel_wise_quantize_dequantize_abs_max"
self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}
scales = []
outputs = self.inputs['X'].copy()
range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
if self.quant_axis == 0:
for i in range(self.inputs['X'].shape[0]):
scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
scales.append(scale_v)
outputs[i] = np.round(outputs[i] * range_v /
scale_v) * scale_v / range_v
elif self.quant_axis == 1:
for i in range(self.inputs['X'].shape[1]):
scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
"float32")
scales.append(scale_v)
outputs[:, i] = np.round(outputs[:, i] * range_v /
scale_v) * scale_v / range_v
self.outputs = {
'Out': outputs,
'OutScale': np.array(scales).astype("float32"),
}
def set_arg(self):
self.quant_axis = 0
self.inputs = {
'X': np.random.random((3, 4, 64, 64)).astype("float32"),
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
x = self.inputs["X"]
gradient = [np.ones(x.shape) / np.product(x.shape)]
self.check_grad(["X"], "Out", user_defined_grads=gradient)
class TestChannelWiseFakeQuantDequantOp1(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 1
self.inputs = {
'X': np.random.random((15, 20, 5, 5)).astype("float32"),
}
class TestChannelWiseFakeQuantDequantOp2(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 0
self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
class TestChannelWiseFakeQuantDequantOp3(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 1
self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save