[Quantization] Conv2d_transpose and mul support channelwise quantization (#25639)

* Conv2d_transpose and mul support channelwise quantization, test=develop
* Skip collecting the out threshold for output tensors whose type is not fp32 or fp64, test=develop
* Fix error in test_user_defined_quantization, test=develop
* Add depthwise_conv_bn_fuse, test=develop
* Add conv_transpose_bn_fuse_pass for post_training_quant, test=develop

@ -368,3 +368,7 @@ REGISTER_PASS(conv_transpose_bn_fuse_pass,
paddle::framework::ir::ConvTransposeBNFusePass);
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
REGISTER_PASS(depthwise_conv_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);

@ -56,6 +56,16 @@ class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
std::string conv_type() const { return "conv2d_transpose"; }
};
class DepthwiseConvBNFusePass : public ConvBNFusePass {
public:
std::string conv_type() const { return "depthwise_conv2d"; }
};
class DepthwiseConvEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
public:
std::string conv_type() const { return "depthwise_conv2d"; }
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -37,11 +37,16 @@ template <typename T>
struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, const int quant_axis,
framework::Tensor* out) {
if (scale_num == 1) {
// Dequant op is before quantized op
// Dequantize the weight of quantized op
auto in_dims = in->dims();
const int64_t channel = in_dims[quant_axis];
const T* scale_factor = scales[0]->data<T>();
if (quant_axis == 0) {
for (int64_t i = 0; i < channel; i++) {
T s = scale_factor[i];
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
@ -50,7 +55,31 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s / max_range;
}
} else if (quant_axis == 1) {
int64_t out_iter = 1;
for (int i = 0; i < quant_axis; i++) {
out_iter *= in_dims[i];
}
int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel);
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) {
auto* cur_in = in_data + i * step_i + j * step_j;
auto* cur_out = out_data + i * step_i + j * step_j;
T s = scale_factor[j];
for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s / max_range;
++cur_in;
++cur_out;
}
}
}
}
} else if (scale_num == 2) {
// Dequant op is after quantized op
// Dequantize the output tensor of quantized op
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
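For the single-scale path the functor now dequantizes along either axis 0 or axis 1 of the weight tensor, one scale per channel on the chosen axis. A minimal NumPy sketch of the same computation (illustrative helper, not part of this patch):

import numpy as np

def channel_dequantize(x, scales, max_range, quant_axis=0):
    # x: quantized weight, scales: one scale per channel along quant_axis.
    # Every channel slice is rescaled by scale / max_range, mirroring the
    # quant_axis == 0 and quant_axis == 1 branches of the CPU functor above.
    assert quant_axis in (0, 1) and len(scales) == x.shape[quant_axis]
    out = x.astype("float32").copy()
    for i, s in enumerate(scales):
        if quant_axis == 0:
            out[i] = out[i] * s / max_range
        else:
            out[:, i] = out[:, i] * s / max_range
    return out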
@ -157,6 +186,18 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker
"Quantization bit numbers in quantization stage. "
"The size of `quant_bits` should be equal to the size of `Scales`.")
.SetDefault({8});
AddAttr<int>("quant_axis",
"(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0)
.AddCustomChecker([](const int& quant_axis) {
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
});
AddComment(R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator.
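The attribute comment is the key contract: quant_axis must point at the output-channel axis of the weight. For conv2d and depthwise_conv2d weights laid out as [Cout, Cin, kH, kW] that is axis 0; for conv2d_transpose weights ([Cin, Cout, kH, kW]) and mul weights ([K, N]) it is axis 1. A small hypothetical helper that mirrors the Python-side _channelwise_quant_axis1_ops set:

# Assumed contents of _channelwise_quant_axis1_ops; conv2d and
# depthwise_conv2d keep the default quant_axis of 0.
_QUANT_AXIS1_OPS = {"conv2d_transpose", "mul"}

def choose_quant_axis(op_type):
    # Return the output-channel axis of the op's weight tensor.
    return 1 if op_type in _QUANT_AXIS1_OPS else 0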

@ -45,8 +45,9 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
};
template <typename T>
__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale,
T max_range, int num, int channel,
T* out) {
int tid = threadIdx.x;
int channel_size = num / channel;
const T* in_c = in + blockIdx.x * channel_size;
@ -56,6 +57,23 @@ __global__ void DequantizeOneScale(const T* in, const T* scale, T max_range,
}
}
template <typename T>
__global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
T max_range, const int num,
const int cin, const int cout,
T* out) {
int cout_wh_size = num / cin;
int wh_size = cout_wh_size / cout;
T s = scale[blockIdx.x];
const T* in_current = in + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
T* out_current = out + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
for (int i = 0; i < wh_size; i++) {
out_current[i] = in_current[i] * s / max_range;
}
}
template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
const T* scale_two, T max_range, int num,
@ -74,18 +92,29 @@ template <typename T>
struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, const int quant_axis,
framework::Tensor* out) {
auto in_dims = in->dims();
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (scale_num == 1) {
int num = in->numel();
const T* scale_factor = scales[0]->data<T>();
if (quant_axis == 0) {
int grid = in_dims[0];
int block = 1024;
DequantizeOneScaleQuantAxis0<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[0], out_data);
} else if (quant_axis == 1) {
// Dequantize weight of Cin * Cout * W * H
int grid = in_dims[1];
int block = in_dims[0];
DequantizeOneScaleQuantAxis1<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[0], in_dims[1],
out_data);
}
} else if (scale_num == 2) {
// No need to consider quant_axis
int num = in->numel();
int batch_size = in->dims()[0];
int channel = in->dims()[1];
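In the quant_axis == 1 CUDA path the weight is viewed as Cin x Cout x (W*H): the grid has one block per output channel (its scale is scale[blockIdx.x]) and each thread in the block dequantizes the W*H slab of one input channel, which is exactly what the threadIdx.x * cout_wh_size + blockIdx.x * wh_size offsets address. A NumPy model of that mapping (sketch only):

import numpy as np

def dequant_axis1_like_cuda(x, scales, max_range):
    # x has layout [Cin, Cout, H, W]; flattening the spatial dims makes each
    # (threadIdx.x, blockIdx.x) pair of the kernel one contiguous row here.
    cin, cout = x.shape[0], x.shape[1]
    flat = x.astype("float32").reshape(cin, cout, -1)
    for b in range(cout):      # blockIdx.x: output channel
        for t in range(cin):   # threadIdx.x: input channel
            flat[t, b] = flat[t, b] * scales[b] / max_range
    return flat.reshape(x.shape)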

@ -33,7 +33,7 @@ template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor** scales, const int scale_num,
T max_range, const int quant_axis, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
@ -63,6 +63,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::Tensor>("Out");
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
auto quant_axis = ctx.Attr<int>("quant_axis");
int max_range = 1;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
@ -70,12 +71,12 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
int scale_num = scales.size();
if (scale_num == 1) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[quant_axis],
platform::errors::PreconditionNotMet(
"The number of first scale values must be the same with "
"quant_axis dimension value of Input(X) when the `Scales` has "
"only one element, but %ld != %ld here.",
scales[0]->numel(), in->dims()[quant_axis]));
max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
} else if (scale_num == 2) {
PADDLE_ENFORCE_EQ(
@ -94,7 +95,8 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
(std::pow(2, quant_bits[1] - 1) - 1);
}
ChannelDequantizeFunctor<DeviceContext, T>()(
dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range),
quant_axis, out);
}
};
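Beyond passing quant_axis through, the kernel still derives max_range from the bit widths: (2^(b-1) - 1) when only the weight scales are present, and the product of the two per-stage ranges when an activation scale is supplied as well. A quick sketch of that arithmetic:

def compute_max_range(quant_bits, scale_num):
    # scale_num == 1: dequantizing a weight; scale_num == 2: weight scales
    # plus an activation scale, so the two ranges multiply.
    max_range = 2 ** (quant_bits[0] - 1) - 1
    if scale_num == 2:
        max_range *= 2 ** (quant_bits[1] - 1) - 1
    return max_range

# compute_max_range([8], 1) == 127; compute_max_range([8, 8], 2) == 127 * 127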

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
@ -39,13 +40,41 @@ template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in_tensor, const int quant_axis,
T* out_abs_max) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto* in_data = in_tensor.data<T>();
auto in_dims = in_tensor.dims();
const int64_t channel = in_dims[quant_axis];
if (quant_axis == 0) {
const int64_t channel_size = in_tensor.numel() / channel;
for (int64_t i = 0; i < channel; i++) {
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
out_abs_max[i] =
std::abs(*(std::max_element(start, end, Compare<T>())));
}
} else if (quant_axis == 1) {
for (int64_t i = 0; i < channel; i++) {
out_abs_max[i] = 0;
}
const int64_t step_i = in_tensor.numel() / in_dims[0];
const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]);
for (int64_t i = 0; i < in_dims[0]; i++) {
for (int64_t j = 0; j < in_dims[1]; j++) {
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
T abs_max = std::abs(*(std::max_element(start, end, Compare<T>())));
out_abs_max[j] = std::max(out_abs_max[j], abs_max);
}
}
}
}
};
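FindChannelAbsMaxFunctor now reduces |x| over every dimension except quant_axis, producing one scale per channel on that axis. The same reduction in NumPy (a sketch, not the Paddle API):

import numpy as np

def find_channel_abs_max(x, quant_axis=0):
    # Max of |x| over all axes except quant_axis; the result has
    # x.shape[quant_axis] entries, one per channel.
    assert quant_axis in (0, 1)
    reduce_axes = tuple(a for a in range(x.ndim) if a != quant_axis)
    return np.abs(x).max(axis=reduce_axes)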
@ -92,27 +121,54 @@ template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto* scale_data = scale.data<T>();
auto* in_data = in.data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto in_dims = in.dims();
const int64_t channel = in_dims[quant_axis];
platform::Transform<platform::CPUDeviceContext> trans;
if (quant_axis == 0) {
const int64_t channel_size = in.numel() / channel;
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
ClipFunctor<T>(-s, s));
}
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
T inv_s = inverse(s);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
}
} else if (quant_axis == 1) {
const int64_t step_i = in.numel() / in_dims[0];
const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]);
for (int i = 0; i < in_dims[0]; i++) {
for (int j = 0; j < in_dims[1]; j++) {
T s = scale_data[j];
T inv_s = inverse(s);
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
auto* cur_out_data = out_data + i * step_i + j * step_j;
trans(ctx, start, end, cur_out_data, ClipFunctor<T>(-s, s));
for (int k = 0; k < step_j; k++) {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
}
}
}
}
}
};
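On either axis the fake-quant step is the same per channel: clip to [-s, s], then snap to the integer grid with round(bin_cnt * x / s). A NumPy sketch of what ChannelClipAndFakeQuantFunctor computes:

import numpy as np

def channel_clip_and_fake_quant(x, scales, bin_cnt, quant_axis=0):
    # scales: per-channel abs-max values along quant_axis;
    # bin_cnt = 2^(bit_length - 1) - 1.
    out = x.astype("float32").copy()
    for i, s in enumerate(scales):
        sl = (i,) if quant_axis == 0 else (slice(None), i)
        out[sl] = np.round(bin_cnt * np.clip(out[sl], -s, s) / s)
    return out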
@ -247,8 +303,9 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
"FakeChannelWiseQuantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"FakeChannelWiseQuantizeAbsMax");
int quant_axis = ctx->Attrs().Get<int>("quant_axis");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
ctx->ShareLoD("X", /*->*/ "Out");
}
@ -269,6 +326,18 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type.");
AddOutput("OutScale", "(Tensor) Current channel wise scale");
AddAttr<int>("quant_axis",
"(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0)
.AddCustomChecker([](const int& quant_axis) {
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
});
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {

@ -75,8 +75,8 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
const int c, T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
@ -100,14 +100,69 @@ __global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c,
}
}
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
const int cin, const int cout,
T* out) {
extern __shared__ T shared_max_data[];
int cout_wh_size = n / cin;
int wh_size = n / (cin * cout);
int tid = threadIdx.x;
int bid = blockIdx.x;
const T* in_current = in + tid * cout_wh_size + bid * wh_size;
shared_max_data[tid] = T(0);
for (int i = 0; i < wh_size; i++) {
T tmp = fabs(in_current[i]);
if (tmp > shared_max_data[tid]) {
shared_max_data[tid] = tmp;
}
}
__syncthreads();
int len = blockDim.x;
for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) {
if (tid < i && tid + i < len &&
shared_max_data[tid] < shared_max_data[tid + i]) {
shared_max_data[tid] = shared_max_data[tid + i];
}
if (i == 1) {
i = 0; // break the loop
}
__syncthreads();
}
if (tid == 0) {
out[bid] = shared_max_data[0];
}
}
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in_tensor, const int quant_axis,
T* out_abs_max) {
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
const int num = in_tensor.numel();
auto in_dims = in_tensor.dims();
int channel = in_dims[quant_axis];
const T* in_data = in_tensor.data<T>();
if (quant_axis == 0) {
int grid = channel;
int block = 1024;
FindChannelAbsMaxKernelQuantAxis0<
T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, channel, out_abs_max);
} else if (quant_axis == 1) {
int grid = in_dims[1];
int block = in_dims[0];
FindChannelAbsMaxKernelQuantAxis1<
T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, in_dims[0], in_dims[1], out_abs_max);
}
}
};
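The quant_axis == 1 kernel folds the per-thread maxima with a halving loop that also handles block sizes (here blockDim.x = Cin) that are not powers of two: the active length starts at blockDim.x and each round keeps the element-wise max of the lower and upper halves. A Python model of that reduction (illustrative only):

def block_max(vals):
    # Mirrors the shared-memory reduction: repeatedly fold the upper half of
    # the still-active range onto the lower half until one value is left.
    data = list(vals)
    length = len(data)
    i = (length + 1) // 2
    while i > 0:
        for tid in range(i):
            if tid + i < length and data[tid] < data[tid + i]:
                data[tid] = data[tid + i]
        length = i
        i = 0 if i == 1 else (i + 1) // 2
    return data[0]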
@ -189,10 +244,12 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
template struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext,
float>;
// ChannelClipAndQuantKernel for quant_axis is 0
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
const int bin_cnt,
const int n, const int c,
T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
@ -211,22 +268,57 @@ __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
}
}
// ChannelClipAndQuantKernel for quant_axis is 1
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis1(const T* in, const T* scale,
const int bin_cnt,
const int n, const int cin,
const int cout, T* out) {
T s = scale[blockIdx.x % cout];
T inv_s = inverse(s);
int wh_size = n / (cin * cout);
const T* in_c = in + blockIdx.x * wh_size;
T* out_c = out + blockIdx.x * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v);
}
}
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
int num = in.numel();
auto in_dims = in.dims();
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
if (quant_axis == 0) {
int grid = in_dims[0];
int block = 1024;
ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
} else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1];
int block = 1024;
ChannelClipAndQuantKernelQuantAxis1<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data);
}
}
};
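For quant_axis == 1 the clip-and-quant launch uses one block per W*H slab, so the grid has Cin * Cout blocks and the scale is picked with blockIdx.x % cout: block b covers input channel b / cout and output channel b % cout. A tiny sketch of that mapping:

def block_to_channels(block_id, cout):
    # (input-channel index, output-channel index) handled by a given block;
    # the kernel reads scale[block_id % cout] for that slab.
    return block_id // cout, block_id % cout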

@ -61,15 +61,15 @@ struct FindRangeAbsMaxFunctor {
template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor,
const int quant_axis, T* out_abs_max);
};
template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
const int quant_axis, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
@ -144,12 +144,13 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis");
auto& dev_ctx = context.template device_context<DeviceContext>();
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
out_scale_data);
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
}
};

@ -29,6 +29,7 @@ from .quantization_pass import _out_scale_op_list
from .quantization_pass import _get_op_input_var_names
from .quantization_pass import _get_op_output_var_names
from .quantization_pass import _get_output_name_index
from .quantization_pass import _channelwise_quant_axis1_ops
__all__ = ['PostTrainingQuantization', 'WeightQuantization']
@ -316,6 +317,7 @@ class PostTrainingQuantization(object):
self._out_scale_op_list = _out_scale_op_list
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self.weight_op_pairs = {}
self._sampling_data = {}
self._quantized_var_kl_threshold = {}
self._quantized_var_min = {}
@ -436,6 +438,8 @@ class PostTrainingQuantization(object):
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
graph = _remove_ctrl_vars(graph)
graph = _apply_pass(self._scope, graph, 'conv_bn_fuse_pass')
graph = _apply_pass(self._scope, graph, 'depthwise_conv_bn_fuse_pass')
graph = _apply_pass(self._scope, graph, 'conv_transpose_bn_fuse_pass')
self._program = graph.to_program()
def _collect_target_varnames(self):
@ -446,10 +450,11 @@ class PostTrainingQuantization(object):
# TODO(juncaipeng), consider the name_scope of skip_quant
_logger.info("Collect quantized variable names ...")
def collect_var_name(var_name_list, persistable_var_names, op_type):
for var_name in var_name_list:
if var_name in persistable_var_names:
self._quantized_weight_var_name.add(var_name)
self.weight_op_pairs[var_name] = op_type
else:
self._quantized_act_var_name.add(var_name)
@ -462,13 +467,15 @@ class PostTrainingQuantization(object):
# For quantized ops, sample inputs and outputs
if op_type in self._quantizable_op_type:
collect_var_name(
_get_op_input_var_names(op), persistable_var_names, op_type)
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)
# For other op, only sample output scale
elif op_type in self._out_scale_op_list:
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)
def _set_activation_persistable(self):
'''
@ -492,35 +499,65 @@ class PostTrainingQuantization(object):
Sample the input threshold(min, max, or abs_max) in every iterations.
'''
assert self._algo in ["abs_max", "min_max"], \
"The algo should be abs_max or min_max to sample min max value."
"The algo should be abs_max or min_max for _sample_threshold."
if self._algo == "abs_max":
self._sample_threshold_abs_max()
elif self._algo == "min_max":
self._sample_threshold_min_max()
def _sample_threshold_abs_max(self):
assert self._algo == "abs_max", \
"The algo should be abs_max for _sample_threshold_abs_max."
# Only calculate abs_max value for weight for once
if self._quantized_var_abs_max == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
abs_max_value = []
if self.weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(var_tensor.shape[1]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[:, i]))))
else:
for i in range(var_tensor.shape[0]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[i]))))
self._quantized_var_abs_max[var_name] = abs_max_value
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_var_abs_max) or \
(abs_max_value > self._quantized_var_abs_max[var_name]):
self._quantized_var_abs_max[var_name] = abs_max_value
elif self._algo == "min_max":
def _sample_threshold_min_max(self):
assert self._algo == "min_max", \
"The algo should be min_max for _sample_threshold_min_max."
if self._quantized_var_min == {} and self._quantized_var_max == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
elif self._weight_quantize_type == "channel_wise_abs_max":
min_value = []
max_value = []
if self.weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(var_tensor.shape[1]):
min_value.append(float(np.min(var_tensor[:, i])))
max_value.append(float(np.max(var_tensor[:, i])))
else:
for i in range(var_tensor.shape[0]):
min_value.append(float(np.min(var_tensor[i])))
max_value.append(float(np.max(var_tensor[i])))
self._quantized_var_min[var_name] = min_value
self._quantized_var_max[var_name] = max_value
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
min_value = float(np.min(var_tensor))
@ -554,11 +591,6 @@ class PostTrainingQuantization(object):
applied in every iteration.
'''
assert self._algo == "KL", "The algo should be KL to sample data."
for var_name in self._quantized_weight_var_name:
if var_name not in self._sampling_data:
var_tensor = _load_variable_data(self._scope, var_name)
self._sampling_data[var_name] = var_tensor
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
@ -584,15 +616,20 @@ class PostTrainingQuantization(object):
# Abs_max threshold for weights
for var_name in self._quantized_weight_var_name:
weight_data = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
weight_threshold = float(np.max(np.abs(weight_data)))
elif self._weight_quantize_type == "channel_wise_abs_max":
weight_threshold = []
if self.weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(weight_data.shape[1]):
weight_threshold.append(
float(np.max(np.abs(weight_data[:, i]))))
else:
for i in range(weight_data.shape[0]):
weight_threshold.append(
float(np.max(np.abs(weight_data[i]))))
self._quantized_var_kl_threshold[var_name] = weight_threshold
# KL threshold for activations
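With the owning op recorded in weight_op_pairs, every threshold-sampling path walks the weight along axis 1 for conv2d_transpose and mul and along axis 0 otherwise. A condensed sketch of the abs_max case (hypothetical names that mirror the pass):

import numpy as np

# Assumed contents of _channelwise_quant_axis1_ops.
_AXIS1_OPS = {"conv2d_transpose", "mul"}

def sample_weight_abs_max(weight, op_type, weight_quantize_type):
    # Per-tensor scale for "abs_max", per-channel scales along the
    # op-dependent axis for "channel_wise_abs_max".
    if weight_quantize_type == "abs_max":
        return float(np.max(np.abs(weight)))
    axis = 1 if op_type in _AXIS1_OPS else 0
    return [
        float(np.max(np.abs(np.take(weight, i, axis=axis))))
        for i in range(weight.shape[axis])
    ]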

@ -33,34 +33,29 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"
def residual_block(img, label, num=1):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
use_cudnn=False,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
hidden = img
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 20, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 20, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
return loss
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
pool_type='max',
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
pool_type='avg',
act="relu")
hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return avg_loss
def pact(x, name=None):
@ -102,7 +97,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
img.stop_gradient = False
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.SGD(learning_rate=0.0001)
opt.minimize(loss)

@ -31,45 +31,45 @@ def dequantize_max_abs(x, scale, max_range):
return y
def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
scales = []
y = x.copy()
max_range = math.pow(2, quant_bit - 1) - 1
if quant_axis == 0:
for i in range(x.shape[0]):
scale = np.max(np.abs(x[i])).astype("float32")
scales.append(scale)
y[i] = np.round(x[i] * max_range / scale)
elif quant_axis == 1:
for i in range(x.shape[1]):
scale = np.max(np.abs(x[:, i])).astype("float32")
scales.append(scale)
y[:, i] = np.round(x[:, i] * max_range / scale)
return y, scales
def channel_wise_dequantize_max_abs(x,
scales,
quant_bits,
quant_axis,
activation_scale=None):
assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
if isinstance(quant_bits, list):
max_range = math.pow(2, quant_bits[0] - 1) - 1
else:
max_range = math.pow(2, quant_bits - 1) - 1
y = x.copy()
if quant_axis == 0:
for i in range(x.shape[0]):
y[i] = x[i] * scales[i] / max_range
elif quant_axis == 1:
for i in range(x.shape[1]):
y[:, i] = x[:, i] * scales[i] / max_range
if activation_scale is not None:
y = y * activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
return y
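A quick round-trip check of the two rewritten helpers (usage sketch; a 4D weight with the output channels on axis 1, as for conv2d_transpose):

import numpy as np

x = np.random.randn(4, 3, 8, 8).astype("float32")
yq, scales = channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=1)
ydq = channel_wise_dequantize_max_abs(yq, scales, quant_bits=[8], quant_axis=1)
# The round-trip error is bounded by one quantization step per channel.
assert np.allclose(x, ydq, atol=float(np.max(np.abs(x))) / 127.0)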
@ -83,9 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
self.set_args()
self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], 1)
ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, 1,
self.activation_scale)
self.inputs = {
@ -105,25 +104,39 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
def set_args(self):
self.quant_bits = [8]
self.data_type = "float32"
self.quant_axis = 0
def setUp(self):
self.set_args()
self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0],
self.quant_axis)
ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
self.quant_axis)
self.inputs = {
'X': yq,
'Scales': [("scales0", np.array(scales).astype(self.data_type))]
}
self.attrs = {
'quant_bits': self.quant_bits,
'quant_axis': self.quant_axis
}
self.outputs = {'Out': ydq}
def test_check_output(self):
self.check_output()
class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1(
TestFakeChannelWiseDequantizeMaxAbsOpOneScale):
def set_args(self):
self.quant_bits = [8]
self.data_type = "float32"
self.quant_axis = 1
class TestFakeDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.num_bits = 8

@ -72,28 +72,62 @@ class TestFakeQuantizeOp2(OpTest):
class TestFakeChannelWiseQuantizeOp(OpTest):
def setUp(self):
self.set_arg()
assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."
self.op_type = "fake_channel_wise_quantize_abs_max"
self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}
scales = []
outputs = self.inputs['X'].copy()
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
if self.quant_axis == 0:
for i in range(self.inputs['X'].shape[0]):
scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
scales.append(scale_v)
outputs[i] = np.round(outputs[i] / scale_v * bnt)
elif self.quant_axis == 1:
for i in range(self.inputs['X'].shape[1]):
scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
"float32")
scales.append(scale_v)
outputs[:, i] = np.round(outputs[:, i] / scale_v * bnt)
self.outputs = {
'Out': outputs,
'OutScale': np.array(scales).astype("float32"),
}
def set_arg(self):
self.quant_axis = 0
self.inputs = {
'X': np.random.random((20, 15, 6, 6)).astype("float32"),
}
def test_check_output(self):
self.check_output()
class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp):
def set_arg(self):
self.quant_axis = 1
self.inputs = {
'X': np.random.random((15, 20, 5, 5)).astype("float32"),
}
class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp):
def set_arg(self):
self.quant_axis = 0
self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp):
def set_arg(self):
self.quant_axis = 1
self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
class TestFakeQuantizeRangeAbsMaxOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize_range_abs_max"
