From 545247d7b4e803a2067c0187b2c3c962ec22629d Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Mon, 4 Mar 2019 17:59:31 +0800
Subject: [PATCH 1/4] add channel wise quantize op.

---
 paddle/fluid/operators/fake_quantize_op.cc    | 62 +++++++++++++++++++
 paddle/fluid/operators/fake_quantize_op.cu    |  2 +
 paddle/fluid/operators/fake_quantize_op.h     | 33 ++++++++++
 .../tests/unittests/test_fake_quantize_op.py  | 24 +++++++
 4 files changed, 121 insertions(+)

diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index 3bb07d3835..c873ee6718 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -134,6 +134,61 @@ $$Out = round(X/scale * range)$$
   }
 };
 
+class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of FakeChannelWiseQuantizeOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("OutScales"),
+        "Output(OutScales) of FakeChannelWiseQuantizeOp should not be null.");
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]});
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::LoDTensor>("X")->type(), ctx.GetPlace());
+  }
+};
+
+class FakeChannelWiseQuantizeAbsMaxOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) Input is float data type.");
+    AddOutput("Out",
+              "(Tensor) Output of quantized low level tensor, "
+              "but also saved as float data type.");
+    AddOutput("OutScales", "(Tensor) Current channel wise scale");
+    AddAttr<int>("bit_length", "(int, default 8)")
+        .SetDefault(8)
+        .AddCustomChecker([](const int& bit_length) {
+          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
+                         "'bit_length' should be between 1 and 16.");
+        });
+    AddComment(R"DOC(
+The scale of FakeChannelWiseQuantize operator is a vector.
+In detail, each channel of the input X has a scale value.
+
+$$scale_c = max(abs(X_c))$$
+$$range = 2^{bit_length - 1} - 1$$
+$$Out_c = round(X_c / scale_c * range)$$
+
+In the above three formulas, the range of c is as follows:
+$$0 \leq c \leq \ the\ channel\ number\ of\ X$$
+)DOC");
+  }
+};
+
 class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
  public:
   FakeQuantizeRangeAbsMaxOp(const std::string& type,
@@ -218,3 +273,10 @@ REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
                   paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
                        ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
+
+REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
+                  ops::FakeChannelWiseQuantizeAbsMaxOp,
+                  ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
+                       ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu
index a0ff639621..5da16a7c73 100644
--- a/paddle/fluid/operators/fake_quantize_op.cu
+++ b/paddle/fluid/operators/fake_quantize_op.cu
@@ -174,5 +174,7 @@ namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                         ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
+                        ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
                         ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
index 7ace7573ec..8b47600e7d 100644
--- a/paddle/fluid/operators/fake_quantize_op.h
+++ b/paddle/fluid/operators/fake_quantize_op.h
@@ -63,6 +63,39 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<framework::Tensor>("X");
+
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto* out_scales = context.Output<framework::Tensor>("OutScales");
+    T* out_scales_data = out_scales->mutable_data<T>(context.GetPlace());
+    out->mutable_data<T>(context.GetPlace());
+
+    int bit_length = context.Attr<int>("bit_length");
+    int bin_cnt = std::pow(2, bit_length - 1) - 1;
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto find_abs_max = FindAbsMaxFunctor<DeviceContext, T>();
+    for (int64_t i = 0; i < in->dims()[0]; i++) {
+      framework::Tensor one_channel = in->Slice(i, i + 1);
+      const T* one_channel_data = one_channel.data<T>();
+      find_abs_max(dev_ctx, one_channel_data, one_channel.numel(),
+                   &out_scales_data[i]);
+    }
+    auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
+    for (int64_t i = 0; i < in->dims()[0]; i++) {
+      framework::Tensor one_channel_in = in->Slice(i, i + 1);
+      framework::Tensor one_channel_out = out->Slice(i, i + 1);
+      framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1);
+      clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt,
+                 &one_channel_out);
+    }
+  }
+};
+
 template <typename DeviceContext, typename T>
 class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
  public:
diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
index 4582b2a0ee..90a90112bd 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
@@ -35,6 +35,30 @@ class TestFakeQuantizeOp(OpTest):
         self.check_output()
 
 
+class TestFakeChannelWiseQuantizeOp(OpTest):
+    def setUp(self):
+        self.op_type = "fake_channel_wise_quantize_abs_max"
+        self.attrs = {'bit_length': 8}
+        self.inputs = {
+            'X': np.random.random((4, 3, 64, 64)).astype("float32"),
+        }
+        scales = []
+        for i in range(self.inputs['X'].shape[0]):
+            scales.append(np.max(np.abs(self.inputs['X'][i])).astype("float32"))
+        outputs = self.inputs['X'].copy()
+        for i, scale in enumerate(scales):
+            outputs[i] = np.round(outputs[i] / scale * (
+                (1 << (self.attrs['bit_length'] - 1)) - 1))
+
+        self.outputs = {
+            'Out': outputs,
+            'OutScales': np.array(scales).astype("float32"),
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestFakeQuantizeRangeAbsMaxOp(OpTest):
     def setUp(self):
         self.op_type = "fake_quantize_range_abs_max"
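
For reference, the per-channel quantization scheme this patch introduces can be
sketched in NumPy. This is a minimal illustration mirroring the unit test above,
not part of the patch; the helper name and the assumption that dim 0 is the
channel axis are illustrative:

    import numpy as np

    def ref_channel_wise_quantize(x, bit_length=8):
        # One abs-max scale per slice along dim 0, matching OutScales.
        bin_cnt = (1 << (bit_length - 1)) - 1
        scales = np.abs(x).reshape(x.shape[0], -1).max(axis=1)
        bcast = scales.reshape((-1,) + (1,) * (x.ndim - 1))
        return np.round(x / bcast * bin_cnt), scales

The kernel performs the same two passes per slice: first the abs-max scales via
FindAbsMaxFunctor, then the scale-and-round step via ClipAndFakeQuantFunctor.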

From 89dee160d18d699075c2bfbfce6d7311dfa4f59f Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Tue, 5 Mar 2019 16:41:46 +0800
Subject: [PATCH 2/4] add channel wise dequantize op.

---
 paddle/fluid/operators/fake_dequantize_op.cc  | 72 +++++++++++++++++++
 paddle/fluid/operators/fake_dequantize_op.cu  |  4 ++
 paddle/fluid/operators/fake_dequantize_op.h   | 51 ++++++++++++++
 paddle/fluid/operators/fake_quantize_op.cc    |  7 +-
 .../unittests/test_fake_dequantize_op.py      | 71 ++++++++++++++++++
 5 files changed, 201 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc
index 5d6488c67e..73ffaae6a5 100644
--- a/paddle/fluid/operators/fake_dequantize_op.cc
+++ b/paddle/fluid/operators/fake_dequantize_op.cc
@@ -76,6 +76,70 @@ $$Out = \frac{scale*X}{ max_range }$$
   }
 };
 
+class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput("X"),
+        "Input(X) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("WeightScales"),
+                   "Input(WeightScales) of FakeChannelWiseDequantizeMaxAbsOp "
+                   "should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(Out) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
+
+    ctx->ShareDim("X", /*->*/ "Out");
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+class FakeChannelWiseDequantizeMaxAbsOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor) The input with float-32/64 type is the "
+             "low precision tensor.");
+    AddInput("ActivationScale",
+             "(float) The activation scale in quantization stage.")
+        .AsDispensable();
+    AddInput("WeightScales",
+             "(float array) The weight scales in quantization stage.");
+    AddOutput("Out",
+              "(Tensor) The output is the dequantized high "
+              "precision tensor.");
+    AddAttr<int>("activation_bits", "Quantization bit number for activation.")
+        .SetDefault(8)
+        .AddCustomChecker([](const int& bit_length) {
+          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
+                         "'activation_bits' should be between 1 and 16.");
+        });
+    AddAttr<int>("weight_bits", "Quantization bit number for weights.")
+        .SetDefault(8)
+        .AddCustomChecker([](const int& bit_length) {
+          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
+                         "'weight_bits' should be between 1 and 16.");
+        });
+
+    AddComment(R"DOC(
+FakeChannelWiseDequantizeMaxAbsOp operator.
+
+This calculation is an opposite operation of FakeChannelWiseQuantizeMaxAbsOp:
+
+$$Out_c = \frac{ActivationScale*WeightScale_c*X_c}{(2^{weight\_bits-1}-1)*(2^{activation\_bits-1}-1)}$$
+
+In the above formula, the range of c is as follows:
+$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
+
+Notes: The per-channel quantization is only applied to weights (channel size scale).
+And the activations use per-layer quantization (only one scale).
+)DOC");
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -88,3 +152,11 @@ REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
 REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
                        ops::FakeDequantizeMaxAbsKernel<CPU, float>,
                        ops::FakeDequantizeMaxAbsKernel<CPU, double>);
+
+REGISTER_OPERATOR(fake_channel_wise_dequantize_max_abs,
+                  ops::FakeChannelWiseDequantizeMaxAbsOp,
+                  ops::FakeChannelWiseDequantizeMaxAbsOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(fake_channel_wise_dequantize_max_abs,
+                       ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, float>,
+                       ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, double>);
diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu
index 225bcc45bc..35dcc69279 100644
--- a/paddle/fluid/operators/fake_dequantize_op.cu
+++ b/paddle/fluid/operators/fake_dequantize_op.cu
@@ -55,3 +55,7 @@ using CUDA = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
                         ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
                         ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(
+    fake_channel_wise_dequantize_max_abs,
+    ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float>,
+    ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>);
diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h
index d9923a10da..c26dfa8332 100644
--- a/paddle/fluid/operators/fake_dequantize_op.h
+++ b/paddle/fluid/operators/fake_dequantize_op.h
@@ -45,5 +45,56 @@ class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
+ public:
+  virtual void Compute(const framework::ExecutionContext& ctx) const {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* weight_scales = ctx.Input<framework::Tensor>("WeightScales");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+
+    PADDLE_ENFORCE_EQ(weight_scales->numel(), in->dims()[0],
+                      "The weight uses the per-channel quantization type, so "
+                      "the number of weight scale values must be the same with "
+                      "first dimension value of Input(X).");
+
+    int activation_bits = ctx.Attr<int>("activation_bits");
+    int weight_bits = ctx.Attr<int>("weight_bits");
+    int range = std::pow(2, weight_bits - 1) - 1;
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    out->mutable_data<T>(dev_ctx.GetPlace());
+
+    auto dequant = DequantizeFunctor<DeviceContext, T>();
+    if (ctx.HasInput("ActivationScale")) {
+      auto* activation_scale = ctx.Input<framework::Tensor>("ActivationScale");
+      PADDLE_ENFORCE_EQ(activation_scale->numel(), 1,
+                        "The activation uses per-layer quantization type, so "
+                        "it must have only one value.");
+      framework::Tensor cpu_weight_scales;
+      framework::TensorCopy(*weight_scales, platform::CPUPlace(),
+                            &cpu_weight_scales);
+      dev_ctx.Wait();
+      const T* weight_scales_data = cpu_weight_scales.data<T>();
+      range *= (std::pow(2, activation_bits - 1) - 1);
+      for (int64_t i = 0; i < in->dims()[0]; i++) {
+        framework::Tensor one_channel_in = in->Slice(i, i + 1);
+        framework::Tensor one_channel_out = out->Slice(i, i + 1);
+        auto max_range = range / weight_scales_data[i];
+        dequant(dev_ctx, &one_channel_in, activation_scale,
+                static_cast<T>(max_range), &one_channel_out);
+      }
+    } else {
+      for (int64_t i = 0; i < in->dims()[0]; i++) {
+        framework::Tensor one_channel_in = in->Slice(i, i + 1);
+        framework::Tensor one_channel_out = out->Slice(i, i + 1);
+        framework::Tensor one_channel_scale = weight_scales->Slice(i, i + 1);
+        dequant(dev_ctx, &one_channel_in, &one_channel_scale,
+                static_cast<T>(range), &one_channel_out);
+      }
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index c873ee6718..70186e5efa 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -180,11 +180,10 @@ The scale of FakeChannelWiseQuantize operator is a vector.
 In detail, each channel of the input X has a scale value.
 
 $$scale_c = max(abs(X_c))$$
-$$range = 2^{bit_length - 1} - 1$$
-$$Out_c = round(X_c / scale_c * range)$$
-
+$$range = 2^{bit\_length - 1} - 1$$
+$$Out_c = round(\frac{X_c * range} {scale_c})$$
 In the above three formulas, the range of c is as follows:
-$$0 \leq c \leq \ the\ channel\ number\ of\ X$$
+$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
 )DOC");
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
index 1bb4662e8d..bd8dad4d59 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
@@ -31,6 +31,77 @@ def dequantize_max_abs(x, scale, max_range):
     return y
 
 
+def channel_wise_quantize_max_abs(x, max_range):
+    scales = []
+    for i in range(x.shape[0]):
+        scales.append(np.max(np.abs(x[i])).astype("float32"))
+
+    y = x.copy()
+    for i, scale in enumerate(scales):
+        y[i] = np.round(y[i] / scale * max_range)
+    return y, scales
+
+
+def channel_wise_dequantize_max_abs(x, scales, max_range):
+    y = x.copy()
+    for i in range(x.shape[0]):
+        y[i] = (scales[i] / max_range) * y[i]
+    return y
+
+
+class TestFakeChannelWiseDequantizeMaxAbsOp(OpTest):
+    def set_args(self):
+        self.weight_bits = 8
+        self.activation_bits = 2
+        self.data_type = "float32"
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "fake_channel_wise_dequantize_max_abs"
+        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        max_range = math.pow(2, self.weight_bits - 1) - 1
+        yq, scales = channel_wise_quantize_max_abs(x, max_range)
+        ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
+
+        self.inputs = {
+            'X': yq,
+            'ActivationScale': np.array(1.0).astype(self.data_type),
+            'WeightScales': np.array(scales).astype(self.data_type)
+        }
+        self.attrs = {
+            'weight_bits': self.weight_bits,
+            'activation_bits': self.activation_bits
+        }
+        self.outputs = {'Out': ydq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestFakeChannelWiseDequantizeMaxAbsOpNoActivationScale(OpTest):
+    def set_args(self):
+        self.weight_bits = 8
+        self.data_type = "float32"
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "fake_channel_wise_dequantize_max_abs"
+        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        max_range = math.pow(2, self.weight_bits - 1) - 1
+        yq, scales = channel_wise_quantize_max_abs(x, max_range)
+        ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
+
+        self.inputs = {
+            'X': yq,
+            'WeightScales': np.array(scales).astype(self.data_type)
+        }
+        self.attrs = {'weight_bits': self.weight_bits}
+        self.outputs = {'Out': ydq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestFakeDequantizeMaxAbsOp(OpTest):
     def set_args(self):
         self.num_bits = 8
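
For reference, the dequantization added in this patch can be sketched in NumPy,
covering both branches of the kernel (with and without the dispensable
ActivationScale input). A minimal illustration only; the helper name is
illustrative, not part of the patch:

    import numpy as np

    def ref_channel_wise_dequantize(x, weight_scales, weight_bits=8,
                                    activation_scale=None, activation_bits=8):
        # Undo the per-channel weight quantization along dim 0.
        max_range = (1 << (weight_bits - 1)) - 1
        bcast = np.asarray(weight_scales).reshape((-1,) + (1,) * (x.ndim - 1))
        y = x * bcast / max_range
        # Optionally undo the per-layer activation quantization as well.
        if activation_scale is not None:
            y = y * activation_scale / ((1 << (activation_bits - 1)) - 1)
        return y

This matches the documented formula:
Out_c = ActivationScale * WeightScale_c * X_c / ((2^{weight_bits-1}-1) * (2^{activation_bits-1}-1)).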

From 806832e09163500fa01b8e9eabb871424dc26dbd Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Tue, 5 Mar 2019 20:15:41 +0800
Subject: [PATCH 3/4] update the input format of channel wise dequantize op.

---
 paddle/fluid/operators/fake_dequantize_op.cc  | 42 ++++++++-----------
 paddle/fluid/operators/fake_dequantize_op.h   | 38 +++++++----------
 .../unittests/test_fake_dequantize_op.py      | 27 ++++++------
 3 files changed, 46 insertions(+), 61 deletions(-)

diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc
index 73ffaae6a5..68c7227e5a 100644
--- a/paddle/fluid/operators/fake_dequantize_op.cc
+++ b/paddle/fluid/operators/fake_dequantize_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/fake_dequantize_op.h"
 #include <string>
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -84,8 +85,8 @@ class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(
         ctx->HasInput("X"),
         "Input(X) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("WeightScales"),
-                   "Input(WeightScales) of FakeChannelWiseDequantizeMaxAbsOp "
+    PADDLE_ENFORCE(ctx->HasInputs("Scales"),
+                   "Input(Scales) of FakeChannelWiseDequantizeMaxAbsOp "
                    "should not be null.");
     PADDLE_ENFORCE(
         ctx->HasOutput("Out"),
@@ -103,39 +104,32 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker
     AddInput("X",
              "(Tensor) The input with float-32/64 type is the "
             "low precision tensor.");
-    AddInput("ActivationScale",
-             "(float) The activation scale in quantization stage.")
-        .AsDispensable();
-    AddInput("WeightScales",
-             "(float array) The weight scales in quantization stage.");
+    AddInput("Scales",
+             "(Tensors) The scales in quantization stage. "
+             "Now, `Scales` is a vector with at most two tensors. "
+             "If Scales has two elements, the second tensor should only have "
+             "one value.")
+        .AsDuplicable();
     AddOutput("Out",
               "(Tensor) The output is the dequantized high "
               "precision tensor.");
-    AddAttr<int>("activation_bits", "Quantization bit number for activation.")
-        .SetDefault(8)
-        .AddCustomChecker([](const int& bit_length) {
-          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
-                         "'activation_bits' should be between 1 and 16.");
-        });
-    AddAttr<int>("weight_bits", "Quantization bit number for weights.")
-        .SetDefault(8)
-        .AddCustomChecker([](const int& bit_length) {
-          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
-                         "'weight_bits' should be between 1 and 16.");
-        });
+    AddAttr<std::vector<int>>(
+        "quant_bits",
+        "Quantization bit numbers in quantization stage. "
+        "The size of `quant_bits` should be equal to the size of `Scales`.")
+        .SetDefault({8});
 
     AddComment(R"DOC(
 FakeChannelWiseDequantizeMaxAbsOp operator.
 
 This calculation is an opposite operation of FakeChannelWiseQuantizeMaxAbsOp:
 
-$$Out_c = \frac{ActivationScale*WeightScale_c*X_c}{(2^{weight\_bits-1}-1)*(2^{activation\_bits-1}-1)}$$
+$$Out_c = \frac{X_c\prod_{i=1}^{n}Scales_{ic}}{\prod_{i=1}^{n}(2^{quant\_bits_i-1}-1)}$$
 
-In the above formula, the range of c is as follows:
-$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
+In the above formula, the range of $c$ is $0 \leq c \lt \ the\ channel\ number\ of\ X$.
+Besides, the size of $quant\_bits$ should be equal to the size of $Scales$; this size is denoted by $n$ in the formula.
 
-Notes: The per-channel quantization is only applied to weights (channel size scale).
-And the activations use per-layer quantization (only one scale).
+Note: in general, per-channel quantization is only applied to weights, while activations use per-layer quantization.
 )DOC");
   }
 };
diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h
index c26dfa8332..549f5039f4 100644
--- a/paddle/fluid/operators/fake_dequantize_op.h
+++ b/paddle/fluid/operators/fake_dequantize_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 
@@ -50,47 +51,40 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
  public:
   virtual void Compute(const framework::ExecutionContext& ctx) const {
     auto* in = ctx.Input<framework::Tensor>("X");
-    auto* weight_scales = ctx.Input<framework::Tensor>("WeightScales");
+    auto scales = ctx.MultiInput<framework::Tensor>("Scales");
     auto* out = ctx.Output<framework::Tensor>("Out");
 
-    PADDLE_ENFORCE_EQ(weight_scales->numel(), in->dims()[0],
-                      "The weight uses the per-channel quantization type, so "
-                      "the number of weight scale values must be the same with "
+    PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0],
+                      "The number of first scale values must be the same with "
                       "first dimension value of Input(X).");
 
-    int activation_bits = ctx.Attr<int>("activation_bits");
-    int weight_bits = ctx.Attr<int>("weight_bits");
-    int range = std::pow(2, weight_bits - 1) - 1;
+    auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
+    int max_range = std::pow(2, quant_bits[0] - 1) - 1;
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     out->mutable_data<T>(dev_ctx.GetPlace());
 
     auto dequant = DequantizeFunctor<DeviceContext, T>();
-    if (ctx.HasInput("ActivationScale")) {
-      auto* activation_scale = ctx.Input<framework::Tensor>("ActivationScale");
-      PADDLE_ENFORCE_EQ(activation_scale->numel(), 1,
-                        "The activation uses per-layer quantization type, so "
-                        "it must have only one value.");
-      framework::Tensor cpu_weight_scales;
-      framework::TensorCopy(*weight_scales, platform::CPUPlace(),
-                            &cpu_weight_scales);
-      dev_ctx.Wait();
-      const T* weight_scales_data = cpu_weight_scales.data<T>();
-      range *= (std::pow(2, activation_bits - 1) - 1);
+    if (scales.size() == 2) {
+      PADDLE_ENFORCE_EQ(
+          scales[1]->numel(), 1,
+          "The second scale tensor should only have one value for now.");
       for (int64_t i = 0; i < in->dims()[0]; i++) {
         framework::Tensor one_channel_in = in->Slice(i, i + 1);
         framework::Tensor one_channel_out = out->Slice(i, i + 1);
-        auto max_range = range / weight_scales_data[i];
-        dequant(dev_ctx, &one_channel_in, activation_scale,
-                static_cast<T>(max_range), &one_channel_out);
+        framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
+        max_range *= (std::pow(2, quant_bits[1] - 1) - 1);
+        dequant(dev_ctx, &one_channel_in, &one_channel_scale,
+                static_cast<T>(max_range), &one_channel_out);
       }
+      dequant(dev_ctx, out, scales[1], static_cast<T>(1), out);
     } else {
       for (int64_t i = 0; i < in->dims()[0]; i++) {
         framework::Tensor one_channel_in = in->Slice(i, i + 1);
         framework::Tensor one_channel_out = out->Slice(i, i + 1);
-        framework::Tensor one_channel_scale = weight_scales->Slice(i, i + 1);
+        framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
         dequant(dev_ctx, &one_channel_in, &one_channel_scale,
-                static_cast<T>(range), &one_channel_out);
+                static_cast<T>(max_range), &one_channel_out);
       }
     }
   }
diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
index bd8dad4d59..8d91d8fd1d 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
@@ -49,53 +49,50 @@ def channel_wise_dequantize_max_abs(x, scales, max_range):
     return y
 
 
-class TestFakeChannelWiseDequantizeMaxAbsOp(OpTest):
+class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
     def set_args(self):
-        self.weight_bits = 8
-        self.activation_bits = 2
+        self.quant_bits = [8, 2]
         self.data_type = "float32"
 
     def setUp(self):
         self.set_args()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
         x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
-        max_range = math.pow(2, self.weight_bits - 1) - 1
+        max_range = math.pow(2, self.quant_bits[0] - 1) - 1
+        max_range *= (math.pow(2, self.quant_bits[1] - 1) - 1)
         yq, scales = channel_wise_quantize_max_abs(x, max_range)
         ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
 
         self.inputs = {
             'X': yq,
-            'ActivationScale': np.array(1.0).astype(self.data_type),
-            'WeightScales': np.array(scales).astype(self.data_type)
-        }
-        self.attrs = {
-            'weight_bits': self.weight_bits,
-            'activation_bits': self.activation_bits
+            'Scales': [("scales0", np.array(scales).astype(self.data_type)),
+                       ("scales1", np.array([1.0]).astype(self.data_type))]
         }
+        self.attrs = {'quant_bits': self.quant_bits}
         self.outputs = {'Out': ydq}
 
     def test_check_output(self):
         self.check_output()
 
 
-class TestFakeChannelWiseDequantizeMaxAbsOpNoActivationScale(OpTest):
+class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
     def set_args(self):
-        self.weight_bits = 8
+        self.quant_bits = [8]
         self.data_type = "float32"
 
     def setUp(self):
         self.set_args()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
         x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
-        max_range = math.pow(2, self.weight_bits - 1) - 1
+        max_range = math.pow(2, self.quant_bits[0] - 1) - 1
         yq, scales = channel_wise_quantize_max_abs(x, max_range)
         ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
 
         self.inputs = {
             'X': yq,
-            'WeightScales': np.array(scales).astype(self.data_type)
+            'Scales': [("scales0", np.array(scales).astype(self.data_type))]
         }
-        self.attrs = {'weight_bits': self.weight_bits}
+        self.attrs = {'quant_bits': self.quant_bits}
         self.outputs = {'Out': ydq}
 
     def test_check_output(self):
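
After this patch the op consumes a duplicable Scales input list plus a
quant_bits vector instead of the fixed ActivationScale/WeightScales pair, so
one kernel serves both the one-scale and two-scale cases. A NumPy sketch of the
generalized contract (the helper name is illustrative; the first scales entry
is per-channel along dim 0, the optional second entry is a single value):

    import numpy as np

    def ref_dequantize_scales_list(x, scales, quant_bits):
        # scales[0]: per-channel scales along dim 0.
        bcast = np.asarray(scales[0]).reshape((-1,) + (1,) * (x.ndim - 1))
        y = x * bcast / ((1 << (quant_bits[0] - 1)) - 1)
        # scales[1] (optional): one per-layer scale.
        if len(scales) == 2:
            y = y * float(scales[1]) / ((1 << (quant_bits[1] - 1)) - 1)
        return y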

From 8063b31e2d485b665303a2010e63909ba53d1664 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Tue, 5 Mar 2019 22:54:22 +0800
Subject: [PATCH 4/4] Reduce redundant code for channel wise dequant op.
 test=develop

---
 paddle/fluid/operators/fake_dequantize_op.h   | 27 +++++++----------
 .../unittests/test_fake_dequantize_op.py      | 30 +++++++++++--------
 2 files changed, 28 insertions(+), 29 deletions(-)

diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h
index 549f5039f4..d05f203853 100644
--- a/paddle/fluid/operators/fake_dequantize_op.h
+++ b/paddle/fluid/operators/fake_dequantize_op.h
@@ -65,27 +65,20 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(dev_ctx.GetPlace());
 
     auto dequant = DequantizeFunctor<DeviceContext, T>();
+    for (int64_t i = 0; i < in->dims()[0]; i++) {
+      framework::Tensor one_channel_in = in->Slice(i, i + 1);
+      framework::Tensor one_channel_out = out->Slice(i, i + 1);
+      framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
+      dequant(dev_ctx, &one_channel_in, &one_channel_scale,
+              static_cast<T>(max_range), &one_channel_out);
+    }
+
     if (scales.size() == 2) {
       PADDLE_ENFORCE_EQ(
           scales[1]->numel(), 1,
           "The second scale tensor should only have one value for now.");
-      for (int64_t i = 0; i < in->dims()[0]; i++) {
-        framework::Tensor one_channel_in = in->Slice(i, i + 1);
-        framework::Tensor one_channel_out = out->Slice(i, i + 1);
-        framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
-        max_range *= (std::pow(2, quant_bits[1] - 1) - 1);
-        dequant(dev_ctx, &one_channel_in, &one_channel_scale,
-                static_cast<T>(max_range), &one_channel_out);
-      }
-      dequant(dev_ctx, out, scales[1], static_cast<T>(1), out);
+      max_range = std::pow(2, quant_bits[1] - 1) - 1;
+      dequant(dev_ctx, out, scales[1], static_cast<T>(max_range), out);
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
index 8d91d8fd1d..32cb23cbfa 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
@@ -31,42 +31,49 @@ def dequantize_max_abs(x, scale, max_range):
     return y
 
 
-def channel_wise_quantize_max_abs(x, max_range):
+def channel_wise_quantize_max_abs(x, quant_bit=8):
     scales = []
     for i in range(x.shape[0]):
         scales.append(np.max(np.abs(x[i])).astype("float32"))
 
     y = x.copy()
+    max_range = math.pow(2, quant_bit - 1) - 1
     for i, scale in enumerate(scales):
         y[i] = np.round(y[i] / scale * max_range)
     return y, scales
 
 
-def channel_wise_dequantize_max_abs(x, scales, max_range):
+def channel_wise_dequantize_max_abs(x,
+                                    scales,
+                                    quant_bits,
+                                    activation_scale=None):
     y = x.copy()
     for i in range(x.shape[0]):
-        y[i] = (scales[i] / max_range) * y[i]
+        y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * y[i]
+    if activation_scale is not None:
+        y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
     return y
 
 
 class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
     def set_args(self):
-        self.quant_bits = [8, 2]
+        self.quant_bits = [8, 8]
         self.data_type = "float32"
+        self.activation_scale = 0.7861
 
     def setUp(self):
         self.set_args()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
         x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
-        max_range = math.pow(2, self.quant_bits[0] - 1) - 1
-        max_range *= (math.pow(2, self.quant_bits[1] - 1) - 1)
-        yq, scales = channel_wise_quantize_max_abs(x, max_range)
-        ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
+        yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0])
+        ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
+                                              self.activation_scale)
 
         self.inputs = {
             'X': yq,
             'Scales': [("scales0", np.array(scales).astype(self.data_type)),
-                       ("scales1", np.array([1.0]).astype(self.data_type))]
+                       ("scales1", np.array(
+                           [self.activation_scale]).astype(self.data_type))]
         }
         self.attrs = {'quant_bits': self.quant_bits}
         self.outputs = {'Out': ydq}
@@ -84,9 +91,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
         self.set_args()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
         x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
-        max_range = math.pow(2, self.quant_bits[0] - 1) - 1
-        yq, scales = channel_wise_quantize_max_abs(x, max_range)
-        ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
+        yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0])
+        ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits)
 
         self.inputs = {
             'X': yq,
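
As a closing sanity check on the series, the updated test helpers imply a
simple round-trip bound that can be verified directly. This assumes the helpers
from the final test_fake_dequantize_op.py above are in scope; the tolerance
follows from round() introducing at most half a quantization step per element:

    import numpy as np

    x = np.random.randn(4, 3, 8, 8).astype("float32")
    yq, scales = channel_wise_quantize_max_abs(x, quant_bit=8)
    ydq = channel_wise_dequantize_max_abs(yq, scales, [8])
    # Per channel the error is at most scale_c / (2 * 127) <= max|x| / 254.
    assert np.abs(ydq - x).max() <= np.abs(x).max() / 254.0 + 1e-6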