@@ -14,13 +14,11 @@
#include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"
#include <algorithm>
#include <memory>
#include <string>
namespace paddle {
namespace framework {
class LoDTensor;
}  // namespace framework
}  // namespace paddle
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
@@ -78,6 +76,12 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
    any_op2_desc->Flush();
    auto dequant_type = quant_dequant_op->Op()->Type();
    auto quantized_op_type = any_op2_desc->Type();
    // get weight tensor
    auto* weight_tensor =
        scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
    auto w_dims = weight_tensor->dims();
    float* quantized_weight_data =
        weight_tensor->mutable_data<float>(platform::CPUPlace());
    // Get weight scale
    if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
@@ -93,26 +97,64 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
          paddle::platform::is_cpu_place(channel_scale_tensor.place()),
          platform::errors::InvalidArgument(
              "Channel scale tensor's place should be CPU."));
      const float* channel_scale_data = channel_scale_tensor.data<float>();
      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
        weight_scale.push_back(range / channel_scale_data[i]);
      // compute the channel wise abs max of the weight tensor
      int quant_axis =
          BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
      PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
                        platform::errors::InvalidArgument(
                            "'quant_axis' should be 0 or 1, but "
                            "the received is %d",
                            quant_axis));
      const int64_t channel = w_dims[quant_axis];
      weight_scale.resize(channel, 0);
      if (quant_axis == 0) {
        const int64_t channel_size = weight_tensor->numel() / channel;
        for (int64_t i = 0; i < channel; i++) {
          auto* start = quantized_weight_data + i * channel_size;
          for (int64_t j = 0; j < channel_size; j++) {
            weight_scale[i] = std::max(std::abs(start[j]), weight_scale[i]);
          }
        }
      } else if (quant_axis == 1) {
        const int64_t step_i = weight_tensor->numel() / w_dims[0];
        const int64_t step_j = weight_tensor->numel() / (w_dims[0] * w_dims[1]);
        for (int64_t i = 0; i < w_dims[0]; i++) {
          for (int64_t j = 0; j < w_dims[1]; j++) {
            auto* start = quantized_weight_data + i * step_i + j * step_j;
            float abs_max = 0;
            for (int64_t k = 0; k < step_j; k++) {
              abs_max = std::max(std::abs(start[k]), abs_max);
            }
            weight_scale[j] = std::max(weight_scale[j], abs_max);
          }
        }
      }
      for (int i = 0; i < channel; i++) {
        PADDLE_ENFORCE_NE(weight_scale[i], 0,
                          platform::errors::InvalidArgument(
                              "Weight scale should be nonzero, but get zero."));
        weight_scale[i] = range / weight_scale[i];
      }
    } else {
      auto scale_name = quant_dequant_op_outscale->Name();
      const LoDTensor& scale_tensor =
          scope->GetVar(scale_name)->Get<LoDTensor>();
      const float* scale_data = scale_tensor.data<float>();
      weight_scale.push_back((range * range) / scale_data[0] / range);
      // compute the abs max of the weight tensor
      float abs_max_weight = 0.;
      for (int j = 0; j < weight_tensor->numel(); j++) {
        abs_max_weight =
            std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
      }
      PADDLE_ENFORCE_NE(abs_max_weight, 0,
                        platform::errors::InvalidArgument(
                            "Weight scale should be nonzero, but get zero"));
      weight_scale.push_back((range * range) / abs_max_weight / range);
    }
    nodes2rm.insert(quant_dequant_op_outscale);
    // perform quantize dequantize operations
    auto* weight_tensor =
        scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
    auto w_dims = weight_tensor->dims();
    float* quantized_weight_data =
        weight_tensor->mutable_data<float>(platform::CPUPlace());
    // If quantized op is fc, weight scale size = 1;
    // If quantized op is not channel wise, weight scale size = 1;
    // If quantized op is conv2d, weight scale size = weight dims[0]
    // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
    if (dequant_type == "fake_quantize_dequantize_abs_max") {
@@ -122,9 +164,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
              "%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
              "requires weight scale size = 1, but got %d.",
              quantized_op_type, weight_scale.size()));
      PADDLE_ENFORCE_NE(weight_scale[0], 0,
                        platform::errors::InvalidArgument(
                            "Weight scale should be nonzero, but get zero"));
      for (int j = 0; j < weight_tensor->numel(); j++) {
        // quantized
        quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];