Add channel-wise quantization in IR pass.

revert-16190-refine_parallel_executor
Zhen Wang 6 years ago
parent 81b4fad8b9
commit ec88b6cc5a

@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -54,10 +55,6 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto scales = ctx.MultiInput<framework::Tensor>("Scales");
auto* out = ctx.Output<framework::Tensor>("Out");
PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0],
"The number of first scale values must be the same with "
"first dimension value of Input(X).");
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
int max_range = std::pow(2, quant_bits[0] - 1) - 1;
@@ -65,6 +62,12 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
out->mutable_data<T>(dev_ctx.GetPlace());
auto dequant = DequantizeFunctor<DeviceContext, T>();
if (scales.size() == 1) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[0],
"The number of first scale values must be the same with "
"first dimension value of Input(X) when the `Scales` has only one "
"element.");
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
@@ -72,8 +75,25 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
static_cast<T>(max_range), &one_channel_out);
}
if (scales.size() == 2) {
} else if (scales.size() == 2) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[1],
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements.");
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int64_t j = 0; j < in->dims()[1]; j++) {
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
framework::Tensor one_channel_scale = scales[0]->Slice(j, j + 1);
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
static_cast<T>(max_range), &one_channel_out);
}
}
PADDLE_ENFORCE_EQ(
scales[1]->numel(), 1,
"The second scale tensor should only have one value at now.");

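For reference, the two branches of this kernel implement the following arithmetic. This is a minimal NumPy sketch with illustrative names, not the Paddle kernel API; it mirrors the Python reference channel_wise_dequantize_max_abs used in the tests further down.

import numpy as np

def dequant_channel_wise(x, scales, quant_bits, activation_scale=None):
    # One scale tensor (scales.size() == 1): per-channel scales over dim 0,
    # used for weights.
    # Two scale tensors (scales.size() == 2): per-channel scales over dim 1 of
    # every sample, followed by one whole-tensor activation scale.
    max_range = (1 << (quant_bits[0] - 1)) - 1
    y = x.astype("float32")
    if activation_scale is None:
        for i in range(x.shape[0]):
            y[i] = x[i] * scales[i] / max_range
    else:
        for j in range(x.shape[1]):
            y[:, j] = x[:, j] * scales[j] / max_range
        y *= activation_scale / ((1 << (quant_bits[1] - 1)) - 1)
    return y
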
@@ -169,10 +169,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
ctx->HasOutput("Out"),
"Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutScales"),
"Output(Scales) of FakeChannelWiseQuantizeOp should not be null.");
ctx->HasOutput("OutScale"),
"Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]});
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
ctx->ShareLoD("X", /*->*/ "Out");
}
@@ -192,7 +192,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
AddOutput("Out",
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type.");
AddOutput("OutScales", "(Tensor) Current channel wise scale");
AddOutput("OutScale", "(Tensor) Current channel wise scale");
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {

@@ -78,8 +78,8 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto* out_scales = context.Output<framework::Tensor>("OutScales");
T* out_scales_data = out_scales->mutable_data<T>(context.GetPlace());
auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
out->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length");
@@ -91,13 +91,13 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
framework::Tensor one_channel = in->Slice(i, i + 1);
const T* one_channel_data = one_channel.data<T>();
find_abs_max(dev_ctx, one_channel_data, one_channel.numel(),
&out_scales_data[i]);
&out_scale_data[i]);
}
auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1);
framework::Tensor one_channel_scale = out_scale->Slice(i, i + 1);
clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt,
&one_channel_out);
}

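The quantize kernel above finds one abs-max scale per slice along the first dimension (find_abs_max) and then clips and rounds each slice against its own scale (clip_quant). A rough NumPy equivalent, illustrative only:

import numpy as np

def quant_channel_wise_abs_max(x, bit_length=8):
    bin_cnt = (1 << (bit_length - 1)) - 1
    # One abs-max scale per slice along dim 0.
    scales = np.abs(x.reshape(x.shape[0], -1)).max(axis=1)
    # Clip to [-scale, scale], map to [-bin_cnt, bin_cnt], then round.
    y = x / scales.reshape([-1] + [1] * (x.ndim - 1))
    y = np.round(np.clip(y, -1.0, 1.0) * bin_cnt)
    return y.astype("float32"), scales.astype("float32")
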
@@ -243,7 +243,12 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.scope_guard(scope):
exe.run(startup)
transform_pass = QuantizationTransformPass(
scope=scope, place=place, activation_quantize_type=quant_type)
scope=scope,
place=place,
activation_quantize_type=quant_type,
weight_quantize_type='channel_wise_abs_max')
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
@@ -296,7 +301,11 @@ class TestQuantizationFreezePass(unittest.TestCase):
fetch_list=[loss, w_var])
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_quantize_type='channel_wise_abs_max')
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
if not for_ci:
marked_nodes = set()
@@ -375,29 +384,32 @@ class TestQuantizationFreezePass(unittest.TestCase):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(
True, seed=1, quant_type='abs_max', for_ci=True)
True, seed=1, quant_type='abs_max', for_ci=False)
def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=True)
self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=False)
def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(
True, seed=1, quant_type='range_abs_max', for_ci=True)
True, seed=1, quant_type='range_abs_max', for_ci=False)
self.freeze_graph(
True,
seed=1,
quant_type='moving_average_abs_max',
for_ci=True)
for_ci=False)
def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard():
self.freeze_graph(
False, seed=2, quant_type='range_abs_max', for_ci=True)
False, seed=2, quant_type='range_abs_max', for_ci=False)
self.freeze_graph(
False, seed=2, quant_type='moving_average_abs_max', for_ci=True)
False,
seed=2,
quant_type='moving_average_abs_max',
for_ci=False)
if __name__ == '__main__':

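Taken together, the changes to this test exercise the new channel-wise weight path roughly as below. This is a condensed sketch: scope, place, quant_type and the graphs are assumed to be built exactly as in the surrounding test, and the import path is the one used by the Paddle 1.x slim tests.

from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass

transform_pass = QuantizationTransformPass(
    scope=scope,
    place=place,
    activation_quantize_type=quant_type,
    weight_quantize_type='channel_wise_abs_max')
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)

# ... run training and collect scales, as the test does ...

freeze_pass = QuantizationFreezePass(
    scope=scope, place=place, weight_quantize_type='channel_wise_abs_max')
freeze_pass.apply(test_graph)
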
@@ -31,15 +31,27 @@ def dequantize_max_abs(x, scale, max_range):
return y
def channel_wise_quantize_max_abs(x, quant_bit=8):
def channel_wise_quantize_max_abs(x, quant_bit=8, use_second_dim=False):
scales = []
if not use_second_dim:
for i in range(x.shape[0]):
scales.append(np.max(np.abs(x[i])).astype("float32"))
y = x.copy()
max_range = math.pow(2, quant_bit - 1) - 1
for i, scale in enumerate(scales):
y[i] = np.round(y[i] / scale * max_range)
y[i] = np.round(x[i] / scale * max_range)
else:
for i in range(x.shape[0]):
s = []
for j in range(x.shape[1]):
s.append(np.max(np.abs(x[i][j])).astype("float32"))
scales.append(s)
scales = np.amax(np.array(scales), axis=0)
y = x.copy()
max_range = math.pow(2, quant_bit - 1) - 1
for i in range(x.shape[0]):
for j, scale in enumerate(scales):
y[i][j] = np.round(x[i][j] / scale * max_range)
return y, scales
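
A quick illustration of the two modes of this helper (the arrays here are made-up examples):

import numpy as np

w = np.random.randn(8, 3, 3, 3).astype("float32")
wq, w_scales = channel_wise_quantize_max_abs(w, quant_bit=8)
# Default mode: one abs-max scale per slice of axis 0, so len(w_scales) == 8.

a = np.random.randn(4, 3, 64, 64).astype("float32")
aq, a_scales = channel_wise_quantize_max_abs(a, 8, use_second_dim=True)
# use_second_dim=True: scales are reduced over axis 0, so a_scales.shape == (3,).
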
@@ -47,10 +59,16 @@ def channel_wise_dequantize_max_abs(x,
scales,
quant_bits,
activation_scale=None):
if activation_scale is None:
y = x.copy()
for i in range(x.shape[0]):
y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * x[i]
else:
y = x.copy()
for i in range(x.shape[0]):
y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * y[i]
if activation_scale is not None:
for j in range(x.shape[1]):
y[i][j] = (scales[j] /
(math.pow(2, quant_bits[0] - 1) - 1)) * x[i][j]
y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
return y
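
A round-trip sketch for the two-scale branch (the activation scale here is an arbitrary example value, not the one used in the test below):

import numpy as np

x = np.random.randn(4, 3, 64, 64).astype("float32")
quant_bits = [8, 8]
activation_scale = 0.5  # arbitrary example value

xq, scales = channel_wise_quantize_max_abs(x, quant_bits[0], use_second_dim=True)
xdq = channel_wise_dequantize_max_abs(xq, scales, quant_bits, activation_scale)
# Up to rounding error, xdq equals x * activation_scale / (2 ** (quant_bits[1] - 1) - 1).
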
@@ -65,7 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
self.set_args()
self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0])
yq, scales = channel_wise_quantize_max_abs(
x, self.quant_bits[0], use_second_dim=True)
ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
self.activation_scale)

@@ -53,7 +53,7 @@ class TestFakeChannelWiseQuantizeOp(OpTest):
self.outputs = {
'Out': outputs,
'OutScales': np.array(scales).astype("float32"),
'OutScale': np.array(scales).astype("float32"),
}
def test_check_output(self):

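With the rename, the expected outputs of the channel-wise quantize op test are keyed by the single 'OutScale' tensor. A schematic sketch of how such expectations can be built from the reference helper above (not the exact test body):

import numpy as np

x = np.random.random((4, 3, 64, 64)).astype("float32")
outputs, scales = channel_wise_quantize_max_abs(x, quant_bit=8)
expected = {
    'Out': outputs,
    'OutScale': np.array(scales).astype("float32"),
}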