Add fake_quantize_op. (#11359)
* Add a fake_quantize_op, which quantizes an input tensor to a tensor with lower bits. guochaorong-patch-1
parent
79d797fde9
commit
8e4b225fe4
@ -0,0 +1,112 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fake_quantize_op.h"
|
||||
#include <string>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class FakeQuantizeOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
FakeQuantizeOp(const std::string &type,
|
||||
const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of FakeQuantizeOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of FakeQuantizeOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
|
||||
"OutMovingScale(Out) of FakeQuantizeOp should not be null");
|
||||
// if (ctx->HasInput("InMovingScale")) {
|
||||
ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
|
||||
//}
|
||||
// if (ctx->HasInput("InScales")) {
|
||||
PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
|
||||
"OutScales(Out) of FakeQuantizeOp should not be null");
|
||||
ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
|
||||
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
|
||||
// ctx->Outputs("OutScales")[0],
|
||||
// "Mean and MeanOut should share the same memory");
|
||||
//}
|
||||
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, outputs, attributes and user-facing documentation of
// the fake_quantize operator.
class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input tensor of the fake_quantize operator.");
    AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
        .AsDispensable();
    AddInput("InMovingScale", "Last scale, used in static quantization.")
        .AsDispensable();
    AddInput("InCurrentIter",
             "Last iteration number, used in static quantization.")
        .AsDispensable();
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScales",
              "(Tensor) scale buffer, used in static quantization.")
        .AsDispensable();
    AddOutput("OutMovingScale", "Current scale.");
    AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
    AddAttr<std::string>("quantize_type",
                         "(string, default abs_max) "
                         "The scaling type of the quantize operator.")
        .SetDefault("abs_max")
        // The kernel only implements these three modes; any other value
        // would silently run with scale = 1, so reject it here.
        .AddCustomChecker([](const std::string &quantize_type) {
          PADDLE_ENFORCE(quantize_type == "abs_max" ||
                             quantize_type == "range_abs_max" ||
                             quantize_type == "moving_average_abs_max",
                         "'quantize_type' should be 'abs_max', "
                         "'range_abs_max' or 'moving_average_abs_max'.");
        });
    AddAttr<int>("window_size",
                 "(int, default 10000) "
                 "Number of scales kept in the scale buffer; used when "
                 "quantize_type is range_abs_max.")
        .SetDefault(10000);
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int &bit_length) {
          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                         "'bit_length' should be between 1 and 16.");
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) "
                  "If true, range_abs_max and moving_average_abs_max use "
                  "InMovingScale as the scale instead of updating it.")
        .SetDefault(false);
    AddComment(R"DOC(
FakeQuantize operator

quantize_type = abs_max:

    $$scale = max(abs(X))$$

quantize_type = range_abs_max:

    $$scale = max(max(abs(X)), history_abs_max)$$

quantize_type = moving_average_abs_max:

    $$scale = 0.1 * scale + 0.9 * new_abs_max$$

The input is clipped to [-scale, scale] and mapped onto the integer grid:

    $$Out = round(clip(X, -scale, scale) / scale * (2^{bit_length - 1} - 1))$$

)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
fake_quantize,
|
||||
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,155 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/clip_op.h"
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
#include "paddle/fluid/platform/transform.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using platform::Transform;
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class FakeQuantizeKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
T FindAbsMax(framework::Tensor* in, int n) const {
|
||||
T* p = in->mutable_data<T>(platform::CPUPlace());
|
||||
T abs_max = (T)0.00000001;
|
||||
for (int i = 0; i < n; i++) {
|
||||
T tmp = fabs(p[i]);
|
||||
if (tmp > abs_max) abs_max = tmp;
|
||||
}
|
||||
return T(abs_max);
|
||||
}
|
||||
T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale,
|
||||
const T& cur_scale, int window_size,
|
||||
int current_iter) const {
|
||||
T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
|
||||
T remove_tmp = sl[current_iter];
|
||||
sl[current_iter] = cur_scale;
|
||||
T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
|
||||
if (max_scale < cur_scale) {
|
||||
max_scale = cur_scale;
|
||||
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
|
||||
int size = (current_iter > window_size) ? window_size : current_iter;
|
||||
max_scale = T(FindAbsMax(scale_list, size));
|
||||
}
|
||||
return max_scale;
|
||||
}
|
||||
|
||||
T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
|
||||
framework::Tensor* out_scale,
|
||||
const T& cur_scale) const {
|
||||
T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
|
||||
T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
|
||||
outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
|
||||
return T(outs[0]);
|
||||
}
|
||||
|
||||
virtual void Compute(const framework::ExecutionContext& context) const {
|
||||
auto* tensor = context.Output<framework::Tensor>("Out");
|
||||
auto* in = context.Input<framework::Tensor>("X");
|
||||
const bool is_test = context.Attr<bool>("is_test");
|
||||
tensor->mutable_data<T>(in->place());
|
||||
|
||||
auto* oms_tensor = context.Output<framework::Tensor>("OutMovingScale");
|
||||
oms_tensor->mutable_data<T>(in->place());
|
||||
|
||||
auto quantize_type =
|
||||
static_cast<std::string>(context.Attr<std::string>("quantize_type"));
|
||||
if (quantize_type == std::string("range_abs_max")) {
|
||||
auto* oss_tensor = context.Output<framework::Tensor>("OutScales");
|
||||
oss_tensor->mutable_data<T>(
|
||||
context.Input<framework::Tensor>("InScales")->place());
|
||||
auto* oci_tensor = context.Output<framework::Tensor>("OutCurrentIter");
|
||||
oci_tensor->mutable_data<T>(
|
||||
context.Input<framework::Tensor>("InCurrentIter")->place());
|
||||
}
|
||||
|
||||
T scale = static_cast<T>(1);
|
||||
int window_size = context.Attr<int>("window_size");
|
||||
int bit_length = context.Attr<int>("bit_length");
|
||||
int bin_cnt = std::pow(2, bit_length - 1) - 1;
|
||||
|
||||
auto& dev =
|
||||
*context.template device_context<DeviceContext>().eigen_device();
|
||||
auto raw_in = framework::EigenVector<T>::Flatten(*in);
|
||||
if (quantize_type == std::string("abs_max")) {
|
||||
auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
|
||||
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
|
||||
scale_out.device(dev) = raw_in.abs().maximum();
|
||||
scale = scale_out(0);
|
||||
|
||||
auto& device_ctx = context.template device_context<DeviceContext>();
|
||||
auto* scale_list = context.Output<framework::Tensor>("OutScales");
|
||||
math::SetConstant<DeviceContext, T> scalar;
|
||||
scale_list->mutable_data<T>(context.GetPlace());
|
||||
scalar(device_ctx, scale_list, static_cast<T>(0));
|
||||
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
|
||||
iter->mutable_data<T>(context.GetPlace());
|
||||
scalar(device_ctx, iter, static_cast<T>(0));
|
||||
} else if (quantize_type == std::string("range_abs_max")) {
|
||||
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
|
||||
if (is_test) {
|
||||
scale = moving_scale->data<T>()[0];
|
||||
} else {
|
||||
auto* it = context.Input<framework::Tensor>("InCurrentIter");
|
||||
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
|
||||
const int* last_iter = it->data<int>();
|
||||
int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
|
||||
auto* scale_list = context.Output<framework::Tensor>("OutScales");
|
||||
auto* saving_scale =
|
||||
context.Output<framework::Tensor>("OutMovingScale");
|
||||
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
|
||||
scale_out.device(dev) = raw_in.abs().maximum();
|
||||
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
|
||||
scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size,
|
||||
current_iter[0]);
|
||||
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
|
||||
(*current_iter) = (*last_iter) + 1;
|
||||
}
|
||||
} else if (quantize_type == std::string("moving_average_abs_max")) {
|
||||
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
|
||||
if (is_test) {
|
||||
scale = moving_scale->data<T>()[0];
|
||||
} else {
|
||||
auto* saving_scale =
|
||||
context.Output<framework::Tensor>("OutMovingScale");
|
||||
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
|
||||
scale_out.device(dev) = raw_in.abs().maximum();
|
||||
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
|
||||
scale = FindMovingAverageAbsMmax(
|
||||
const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
|
||||
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
|
||||
}
|
||||
}
|
||||
|
||||
Transform<DeviceContext> trans;
|
||||
trans(context.template device_context<DeviceContext>(), in->data<T>(),
|
||||
in->data<T>() + in->numel(), tensor->mutable_data<T>(in->place()),
|
||||
ClipFunctor<T>(-scale, scale));
|
||||
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
|
||||
auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
|
||||
eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -1 +0,0 @@
|
||||
../dense/convert_protobin.sh
|
@ -0,0 +1 @@
|
||||
../dense/convert_protobin.sh
|
@ -1 +0,0 @@
|
||||
../dense/convert_protobin.sh
|
@ -0,0 +1 @@
|
||||
../dense/convert_protobin.sh
|
@ -1 +0,0 @@
|
||||
../dense/convert_protobin.sh
|
@ -0,0 +1 @@
|
||||
../dense/convert_protobin.sh
|
@ -0,0 +1,51 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestFakeQuantizeOp(OpTest):
    """Checks fake_quantize in `abs_max` mode against a NumPy reference."""

    def setUp(self):
        self.op_type = "fake_quantize"
        self.attrs = {
            'bit_length': 8,
            'quantize_type': 'abs_max',
            'window_size': 10000
        }
        window = self.attrs['window_size']
        bits = self.attrs['bit_length']

        # Random input plus zero-initialised scale-tracking state.
        x = np.random.random((10, 10)).astype("float32")
        self.inputs = {
            'X': x,
            'InScales': np.zeros(window).astype("float32"),
            'InCurrentIter': np.zeros(1).astype("float32"),
            'InMovingScale': np.zeros(1).astype("float32")
        }

        # Reference scale: the largest magnitude in X.
        abs_max = np.max(np.abs(x)).astype("float32")
        self.scale = {'abs_max': abs_max}

        # Reference output: X mapped onto [-(2^(bits-1) - 1), 2^(bits-1) - 1].
        quantized = np.round(x / abs_max * ((1 << (bits - 1)) - 1))
        self.outputs = {
            'Out': quantized,
            'OutScales': np.zeros(window).astype("float32"),
            'OutMovingScale': np.array([abs_max]).astype("float32"),
            'OutCurrentIter': np.zeros(1).astype("float32")
        }

    def test_check_output(self):
        self.check_output()
|
||||
|
||||
|
||||
# Run the test case when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
Loading…
Reference in new issue