Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into unsqueeze_op
commit fbef49e772
@ -0,0 +1,110 @@
Fixed-point quantization uses lower bits, for example, 2-bit, 3-bit or 8-bit fixed point, to represent weights and activations, which are usually stored as single-precision floating-point values with 32 bits. The fixed-point representation reduces memory bandwidth, power consumption, computational resources and model storage requirements. It is especially important for inference in embedded-device deployment.

According to some experiments, directly quantizing a model trained in floating point works well for large models, such as the VGG model with its many parameters, but the accuracy drops a lot for small models. In order to improve the tradeoff between accuracy and latency, many quantized training approaches have been proposed.

This document describes the design of a quantized training framework on Fluid. The first part introduces how to quantize, the second part describes the quantized training framework, and the last part illustrates how to calculate the quantization scale.

### How to quantize

There are many ways to quantize a float value to a fixed-point value. For example:

$$ r = min(max(x, a), b)$$

$$ s = \frac{b - a}{n - 1} $$

$$ q = \left \lfloor \frac{r - a}{s} \right \rceil $$

where $x$ is the float value to be quantized, $[a, b]$ is the quantization range, $a$ is the minimum value and $b$ is the maximum value. $\left \lfloor \right \rceil$ denotes rounding to the nearest integer. If the quantization level is $k$, then $n = 2^k$; for example, when $k$ is 8, $n$ is 256. $q$ is the quantized integer.
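
As a concrete illustration of these formulas (not code from the framework; the function name and sample values are made up), a minimal NumPy sketch of *min-max* quantization could look like this:

```python
import numpy as np

def min_max_quantize(x, a, b, k=8):
    """Quantize x into n = 2^k integer levels over the range [a, b]."""
    n = 2 ** k
    r = np.minimum(np.maximum(x, a), b)   # clip to the quantization range
    s = (b - a) / (n - 1)                 # step size between adjacent levels
    return np.round((r - a) / s).astype(np.int64)

x = np.array([-1.5, -0.2, 0.0, 0.7, 2.3], dtype=np.float32)
print(min_max_quantize(x, a=-1.0, b=1.0))   # integer levels in [0, 255]
```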
The quantization we apply is parameterized by the number of quantization levels and the maximum absolute value:

$$ M = max(abs(x)) $$

$$ q = \left \lfloor \frac{x}{M} * (n - 1) \right \rceil $$

where $x$ is the float value to be quantized and $M$ is the maximum absolute value. $\left \lfloor \right \rceil$ denotes rounding to the nearest integer. For 8-bit quantization, $n=2^{8}=256$. $q$ is the quantized integer.
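
Similarly, a minimal NumPy sketch of this *max-abs* scheme (again illustrative only, not framework code) is:

```python
import numpy as np

def max_abs_quantize(x, k=8):
    """Quantize x by its maximum absolute value into 2^k levels."""
    n = 2 ** k
    m = np.max(np.abs(x))                          # quantization scale M
    q = np.round(x / m * (n - 1)).astype(np.int64)
    return q, m

x = np.random.randn(4, 4).astype(np.float32)
q, m = max_abs_quantize(x)
x_dq = q / (2 ** 8 - 1) * m                        # dequantized approximation of x
```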
Whether *min-max* quantization or *max-abs* quantization is used, both can be represented as:

$q = scale * r + b$

We call the *min-max* and *max-abs* values the quantization arguments, also known as the quantization scale or quantization range.

How to calculate the quantization scale (or maximum absolute value) for inference will be described in the last part.
### Training Framework

#### Forward pass

The forward pass uses simulated quantization, see Figure 1.

The training framework is shown in the following figure.

<p align="center">
<img src="quantization_forward.png" width="300" height="340"><br/>
Figure 1. Forward in training with simulated quantization.
</p>

- First, both the input and the weight are quantized to 8-bit integers.
- Second, do the multiplication (or convolution) operation with the integers.
- Third, dequantize the multiplication (or convolution) results to 32-bit floating point.
- Finally, do the bias-addition in 32-bit floating point. Here, the bias is not quantized.

For general matrix multiplication (GEMM), quantize $X$ and $W$:

$$ X_q = \left \lfloor \frac{X}{X_m} * (n - 1) \right \rceil $$

$$ W_q = \left \lfloor \frac{W}{W_m} * (n - 1) \right \rceil $$

Do GEMM:

$$ Y = X_q * W_q $$

Dequantize $Y$:

$$
\begin{align}
Y_{dq} &=\frac{Y}{(n - 1) * (n - 1)} * X_m * W_m \\\
&=\frac{X_q * W_q}{(n - 1) * (n - 1)} * X_m * W_m \\\
&=(\frac{X_q}{n - 1} * X_m) * (\frac{W_q}{n - 1} * W_m)
\end{align}
$$

From these formulas, the dequantization can also be moved before the GEMM: first dequantize $X_q$ and $W_q$, then do the GEMM. The forward workflow in training is therefore equivalent to the following framework.
<p align="center">
<img src="quantization_equivalent_forward.png" width="300" height="330"><br/>
Figure 2. Equivalent forward in training with simulated quantization.
</p>

We use this equivalent workflow in the training. In our design, a quantization transpiler inserts the quantization operator and the de-quantization operator into the Fluid `ProgramDesc`. Since the outputs of the quantization and de-quantization operators are still in floating point, they are called fake quantization and de-quantization operators, and the training framework is called simulated quantization. A small numeric sketch of this equivalence is given below.
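
The following NumPy check (a sketch with made-up names, not the transpiler itself) makes the equivalence of Figure 1 and Figure 2 concrete: quantizing, multiplying and dequantizing the product gives the same result as dequantizing the operands first and multiplying in floating point.

```python
import numpy as np

n = 2 ** 8  # 8-bit quantization levels

def quantize(t):
    m = np.max(np.abs(t))                  # max-abs scale
    return np.round(t / m * (n - 1)), m

X, W = np.random.randn(3, 4), np.random.randn(4, 2)
Xq, Xm = quantize(X)
Wq, Wm = quantize(W)

# Figure 1: integer GEMM, then dequantize the product
Y_dq = (Xq @ Wq) / ((n - 1) * (n - 1)) * Xm * Wm

# Figure 2: dequantize the operands, then float GEMM
Y_eq = (Xq / (n - 1) * Xm) @ (Wq / (n - 1) * Wm)

assert np.allclose(Y_dq, Y_eq)
```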
#### Backward pass

See Figure 3. The gradients are calculated using the dequantized weights and activations. All inputs and outputs are 32-bit floating point. In the weight-updating process, the gradients are added to the original float weights, not to the quantized or dequantized weights.

<p align="center">
<img src="quantization_backward_and_optimization.png"><br/>
Figure 3. Backward and weight updating in training with simulated quantization.
</p>

So the quantization transpiler will change some inputs of the corresponding backward operators.
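
A rough sketch of this update rule (illustrative NumPy only, not the transpiler's actual code): the gradient is computed against the dequantized weight, but it is applied to the original float weight.

```python
import numpy as np

n, lr = 2 ** 8, 0.01
w = np.random.randn(4, 2)                        # original float weight

# forward: fake-quantize then dequantize the weight
w_m = np.max(np.abs(w))
w_dq = np.round(w / w_m * (n - 1)) / (n - 1) * w_m

# backward: the gradient is computed with w_dq (placeholder value here)
grad_w = np.ones_like(w_dq)

# update: applied to the original float weight, not to w_dq
w -= lr * grad_w
```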
### How to calculate quantization scale

There are two strategies to calculate the quantization scale; we call them the dynamic and static strategies. The dynamic strategy calculates the quantization scale value in each iteration. The static strategy keeps a fixed quantization scale across different inputs.

For weights, we apply the dynamic strategy during training, that is to say, the quantization scale is recalculated in each iteration until the training is finished.

For activations, the quantization scales are estimated during training and then used in inference. There are several different ways to estimate them:

1. Calculate the mean of the maximum absolute values during a window.
2. Calculate the max of the maximum absolute values during a window.
3. Calculate the running mean of the maximum absolute values during a window, as follows:

$$ V_t = (1 - k) * V + k * V_{t-1} $$

where $V$ is the maximum absolute value of the current batch, $V_t$ is the running mean value and $k$ is a decay factor, such as 0.9.
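
A short NumPy sketch of the three estimators over a window of per-batch maximum absolute values (the numbers are made up):

```python
import numpy as np

window = np.array([1.2, 0.8, 1.5, 1.1])   # max(abs(x)) of recent batches
k = 0.9                                   # decay factor

mean_of_max = window.mean()               # 1. mean of the maximum absolute values
max_of_max = window.max()                 # 2. max of the maximum absolute values

v_t = window[0]                           # 3. running mean V_t
for v in window[1:]:
    v_t = (1 - k) * v + k * v_t
```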
@ -0,0 +1,112 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fake_quantize_op.h"
#include <string>

namespace paddle {
namespace operators {

class FakeQuantizeOp : public framework::OperatorWithKernel {
 public:
  FakeQuantizeOp(const std::string &type,
                 const framework::VariableNameMap &inputs,
                 const framework::VariableNameMap &outputs,
                 const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of FakeQuantizeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of FakeQuantizeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
                   "OutMovingScale(Out) of FakeQuantizeOp should not be null");
    // if (ctx->HasInput("InMovingScale")) {
    ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
    //}
    // if (ctx->HasInput("InScales")) {
    PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
                   "OutScales(Out) of FakeQuantizeOp should not be null");
    ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
    // PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
    //                   ctx->Outputs("OutScales")[0],
    //                   "Mean and MeanOut should share the same memory");
    //}
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input tensor of scale operator.");
    AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
        .AsDispensable();
    AddInput("InMovingScale", "Last scale, used in static quantization.")
        .AsDispensable();
    AddInput("InCurrentIter",
             "Last iteration number, used in static quantization.")
        .AsDispensable();
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScales",
              "(Tensor) scale buffer, used in static quantization.")
        .AsDispensable();
    AddOutput("OutMovingScale", "Current scale");
    AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
    AddAttr<std::string>("quantize_type",
                         "(string, default abs_max)"
                         "The scaling type of the quantize operator.")
        .SetDefault("abs_max");
    AddAttr<int>("window_size", "(int, default 10000)").SetDefault(10000);
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int &bit_length) {
          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                         "'bit_length' should be between 1 and 16.");
        });
    AddAttr<bool>("is_test", "").SetDefault(false);
    AddComment(R"DOC(
FakeQuantize operator

quantize_type = abs_max:

$$scale = max(abs(x))$$

quantize_type = range_abs_max:

$$scale = max(max(abs(x)), history_abs_max)$$

quantize_type = moving_average_abs_max:

$$scale = 0.1*scale+0.9*new_abs_max$$

$$Out = scale*X$$

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
    fake_quantize,
    ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
    ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
@ -0,0 +1,155 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

using platform::Transform;

template <typename DeviceContext, typename T>
class FakeQuantizeKernel : public framework::OpKernel<T> {
 public:
  // Returns the maximum absolute value of the first n elements of `in`.
  T FindAbsMax(framework::Tensor* in, int n) const {
    T* p = in->mutable_data<T>(platform::CPUPlace());
    T abs_max = (T)0.00000001;
    for (int i = 0; i < n; i++) {
      T tmp = fabs(p[i]);
      if (tmp > abs_max) abs_max = tmp;
    }
    return T(abs_max);
  }

  // Keeps a sliding window of recent scales and returns the maximum scale
  // inside the window (used by the range_abs_max quantize type).
  T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale,
                    const T& cur_scale, int window_size,
                    int current_iter) const {
    T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
    T remove_tmp = sl[current_iter];
    sl[current_iter] = cur_scale;
    T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
    if (max_scale < cur_scale) {
      max_scale = cur_scale;
    } else if (fabs(remove_tmp - max_scale) < 1e-6) {
      int size = (current_iter > window_size) ? window_size : current_iter;
      max_scale = T(FindAbsMax(scale_list, size));
    }
    return max_scale;
  }

  // Updates the scale as a running average of the previous scale and the
  // current batch's maximum absolute value (moving_average_abs_max type).
  T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
                             framework::Tensor* out_scale,
                             const T& cur_scale) const {
    T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
    T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
    outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
    return T(outs[0]);
  }

  virtual void Compute(const framework::ExecutionContext& context) const {
    auto* tensor = context.Output<framework::Tensor>("Out");
    auto* in = context.Input<framework::Tensor>("X");
    const bool is_test = context.Attr<bool>("is_test");
    tensor->mutable_data<T>(in->place());

    auto* oms_tensor = context.Output<framework::Tensor>("OutMovingScale");
    oms_tensor->mutable_data<T>(in->place());

    auto quantize_type =
        static_cast<std::string>(context.Attr<std::string>("quantize_type"));
    if (quantize_type == std::string("range_abs_max")) {
      auto* oss_tensor = context.Output<framework::Tensor>("OutScales");
      oss_tensor->mutable_data<T>(
          context.Input<framework::Tensor>("InScales")->place());
      auto* oci_tensor = context.Output<framework::Tensor>("OutCurrentIter");
      oci_tensor->mutable_data<T>(
          context.Input<framework::Tensor>("InCurrentIter")->place());
    }

    T scale = static_cast<T>(1);
    int window_size = context.Attr<int>("window_size");
    int bit_length = context.Attr<int>("bit_length");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;

    auto& dev =
        *context.template device_context<DeviceContext>().eigen_device();
    auto raw_in = framework::EigenVector<T>::Flatten(*in);
    if (quantize_type == std::string("abs_max")) {
      auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
      auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
      scale_out.device(dev) = raw_in.abs().maximum();
      scale = scale_out(0);

      auto& device_ctx = context.template device_context<DeviceContext>();
      auto* scale_list = context.Output<framework::Tensor>("OutScales");
      math::SetConstant<DeviceContext, T> scalar;
      scale_list->mutable_data<T>(context.GetPlace());
      scalar(device_ctx, scale_list, static_cast<T>(0));
      auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
      iter->mutable_data<T>(context.GetPlace());
      scalar(device_ctx, iter, static_cast<T>(0));
    } else if (quantize_type == std::string("range_abs_max")) {
      auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
      if (is_test) {
        scale = moving_scale->data<T>()[0];
      } else {
        auto* it = context.Input<framework::Tensor>("InCurrentIter");
        auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
        const int* last_iter = it->data<int>();
        int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
        auto* scale_list = context.Output<framework::Tensor>("OutScales");
        auto* saving_scale =
            context.Output<framework::Tensor>("OutMovingScale");
        auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
        scale_out.device(dev) = raw_in.abs().maximum();
        scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
        scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size,
                                current_iter[0]);
        saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
        (*current_iter) = (*last_iter) + 1;
      }
    } else if (quantize_type == std::string("moving_average_abs_max")) {
      auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
      if (is_test) {
        scale = moving_scale->data<T>()[0];
      } else {
        auto* saving_scale =
            context.Output<framework::Tensor>("OutMovingScale");
        auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
        scale_out.device(dev) = raw_in.abs().maximum();
        scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
        scale = FindMovingAverageAbsMmax(
            const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
        saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
      }
    }

    // Clip the input to [-scale, scale], then map it to integer bins.
    Transform<DeviceContext> trans;
    trans(context.template device_context<DeviceContext>(), in->data<T>(),
          in->data<T>() + in->numel(), tensor->mutable_data<T>(in->place()),
          ClipFunctor<T>(-scale, scale));
    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
    auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
    eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round();
  }
};

}  // namespace operators
}  // namespace paddle
@ -1 +0,0 @@
../dense/convert_protobin.sh
@ -0,0 +1 @@
../dense/convert_protobin.sh
@ -1 +0,0 @@
../dense/convert_protobin.sh
@ -0,0 +1 @@
../dense/convert_protobin.sh
@ -1 +0,0 @@
../dense/convert_protobin.sh
@ -0,0 +1 @@
../dense/convert_protobin.sh
@ -0,0 +1,51 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class TestFakeQuantizeOp(OpTest):
    def setUp(self):
        self.op_type = "fake_quantize"
        self.attrs = {
            'bit_length': 8,
            'quantize_type': 'abs_max',
            'window_size': 10000
        }
        self.inputs = {
            'X': np.random.random((10, 10)).astype("float32"),
            'InScales': np.zeros(self.attrs['window_size']).astype("float32"),
            'InCurrentIter': np.zeros(1).astype("float32"),
            'InMovingScale': np.zeros(1).astype("float32")
        }
        self.scale = {
            'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32")
        }
        self.outputs = {
            'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * (
                (1 << (self.attrs['bit_length'] - 1)) - 1)),
            'OutScales': np.zeros(self.attrs['window_size']).astype("float32"),
            'OutMovingScale':
            np.array([self.scale['abs_max']]).astype("float32"),
            'OutCurrentIter': np.zeros(1).astype("float32")
        }

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
@ -0,0 +1,103 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest


def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap,
                      rpn_negative_overlap, fg_fraction):
    iou = np.transpose(iou)
    anchor_to_gt_max = iou.max(axis=1)
    gt_to_anchor_argmax = iou.argmax(axis=0)
    gt_to_anchor_max = iou[gt_to_anchor_argmax, np.arange(iou.shape[1])]
    anchors_with_max_overlap = np.where(iou == gt_to_anchor_max)[0]

    tgt_lbl = np.ones((iou.shape[0], ), dtype=np.int32) * -1
    tgt_lbl[anchors_with_max_overlap] = 1
    tgt_lbl[anchor_to_gt_max >= rpn_positive_overlap] = 1

    num_fg = int(fg_fraction * rpn_batch_size_per_im)
    fg_inds = np.where(tgt_lbl == 1)[0]
    if len(fg_inds) > num_fg:
        disable_inds = np.random.choice(
            fg_inds, size=(len(fg_inds) - num_fg), replace=False)
        tgt_lbl[disable_inds] = -1
        fg_inds = np.where(tgt_lbl == 1)[0]

    num_bg = rpn_batch_size_per_im - np.sum(tgt_lbl == 1)
    bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0]
    if len(bg_inds) > num_bg:
        enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
        tgt_lbl[enable_inds] = 0
    bg_inds = np.where(tgt_lbl == 0)[0]

    loc_index = fg_inds
    score_index = np.hstack((fg_inds, bg_inds))
    tgt_lbl = np.expand_dims(tgt_lbl, axis=1)
    return loc_index, score_index, tgt_lbl


class TestRpnTargetAssignOp(OpTest):
    def setUp(self):
        iou = np.random.random((10, 8)).astype("float32")
        self.op_type = "rpn_target_assign"
        self.inputs = {'DistMat': iou}
        self.attrs = {
            'rpn_batch_size_per_im': 256,
            'rpn_positive_overlap': 0.95,
            'rpn_negative_overlap': 0.3,
            'fg_fraction': 0.25,
            'fix_seed': True
        }
        loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 256, 0.95, 0.3,
                                                            0.25)
        self.outputs = {
            'LocationIndex': loc_index,
            'ScoreIndex': score_index,
            'TargetLabel': tgt_lbl,
        }

    def test_check_output(self):
        self.check_output()


class TestRpnTargetAssignOp2(OpTest):
    def setUp(self):
        iou = np.random.random((10, 20)).astype("float32")
        self.op_type = "rpn_target_assign"
        self.inputs = {'DistMat': iou}
        self.attrs = {
            'rpn_batch_size_per_im': 128,
            'rpn_positive_overlap': 0.5,
            'rpn_negative_overlap': 0.5,
            'fg_fraction': 0.5,
            'fix_seed': True
        }
        loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 128, 0.5, 0.5,
                                                            0.5)
        self.outputs = {
            'LocationIndex': loc_index,
            'ScoreIndex': score_index,
            'TargetLabel': tgt_lbl,
        }

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    unittest.main()