@ -50,6 +50,15 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
} else if ( type = = " concat " ) {
op - > SetInput ( " X " , inputs ) ;
op - > SetOutput ( " Out " , outputs ) ;
} else if ( type = = " fc " ) {
op - > SetInput ( " Input " , { inputs [ 0 ] } ) ;
PADDLE_ENFORCE_EQ ( inputs . size ( ) , 2UL ,
platform : : errors : : InvalidArgument (
" The fc inputs should contain input and weights, but "
" now the size of inputs is %d " ,
inputs . size ( ) ) ) ;
op - > SetInput ( " W " , { inputs [ 1 ] } ) ;
op - > SetOutput ( " Out " , outputs ) ;
}
}
@ -176,6 +185,36 @@ ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
return prog ;
}
// a->fc->b
// b->Dequant1->c
// c->Concat1->d
ProgramDesc BuildFcDequantConcatProgramDesc ( bool use_mkldnn , float scale_out ,
float scale ) {
ProgramDesc prog ;
for ( auto & v : variable_names ) {
prog . MutableBlock ( 0 ) - > Var ( v ) ;
}
SetOp ( & prog , " fc " , " Fc1 " , { " a " , " w1 " } , { " b " } , use_mkldnn , scale_out ) ;
SetOp ( & prog , " dequantize " , " Dequant1 " , { " b " } , { " c " } , use_mkldnn , scale ) ;
SetOp ( & prog , " concat " , " Concat1 " , { " c " } , { " d " } , use_mkldnn ) ;
return prog ;
}
// a->fc->b
// b->Dequant1->c
// b->concat->d
ProgramDesc BuildFcDequantFcProgramDesc ( bool use_mkldnn , float scale_out ,
float scale ) {
ProgramDesc prog ;
for ( auto & v : variable_names ) {
prog . MutableBlock ( 0 ) - > Var ( v ) ;
}
SetOp ( & prog , " fc " , " Fc1 " , { " a " , " w1 " } , { " b " } , use_mkldnn , scale_out ) ;
SetOp ( & prog , " dequantize " , " Dequant1 " , { " b " } , { " c " } , use_mkldnn , scale ) ;
SetOp ( & prog , " concat " , " Concat1 " , { " b " } , { " d " } , use_mkldnn ) ;
return prog ;
}
// a->Conv1->b
// b->Dequant1(Scale1)->c
// b->Conv2->d
@ -261,6 +300,23 @@ void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
}
}
// check requant_op scales
void IsForceFp32OutputTest ( const ProgramDesc & prog , std : : string op_type ,
bool target_is_force_fp32_output ) {
std : : unique_ptr < ir : : Graph > graph ( new ir : : Graph ( prog ) ) ;
PrepareGraph ( & graph , prog ) ;
RegisterPass ( & graph ) ;
for ( auto * node : graph - > Nodes ( ) ) {
if ( node - > IsOp ( ) & & node - > Op ( ) - > Type ( ) = = op_type ) {
bool is_force_fp32_output =
node - > Op ( ) - > GetAttrIfExists < bool > ( " force_fp32_output " ) ;
EXPECT_EQ ( is_force_fp32_output , target_is_force_fp32_output ) ;
}
}
}
// From Conv1->d->Dequant->e->Quant->f->Conv2
// To Conv1->d->Conv2
TEST ( CpuQuantizeSquashPass , equal_scales ) {
@ -362,8 +418,12 @@ TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) {
auto remove_nodes = 2 ;
CountNodeTest ( BuildConvDequantConcatProgramDesc ( use_mkldnn , scale_out , scale ) ,
remove_nodes ) ;
IsForceFp32OutputTest (
BuildConvDequantConcatProgramDesc ( use_mkldnn , scale_out , scale ) , " conv2d " ,
true ) ;
}
// If there are more than one op after conv->dequantize, do not fuse
TEST ( CpuQuantizeSquashPass , conv_dequant_more_than_one_op_after_conv ) {
auto scale_out = 1.0f ;
auto scale = 1.2345f ;
@ -372,6 +432,39 @@ TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
auto remove_nodes = 0 ;
CountNodeTest ( BuildConvDequantConvProgramDesc ( use_mkldnn , scale_out , scale ) ,
remove_nodes ) ;
IsForceFp32OutputTest (
BuildConvDequantConvProgramDesc ( use_mkldnn , scale_out , scale ) , " conv2d " ,
false ) ;
}
// from
// a->fc->b->Dequant1->c->Concat1->d
// to
// a->fc->c->Concat->d
TEST ( CpuQuantizeSquashPass , fc_dequant_only_one_output ) {
auto scale_out = 1.0f ;
auto scale = 1.2345f ;
auto use_mkldnn = true ;
// remove 2 nodes: b, Dequant1
auto remove_nodes = 2 ;
CountNodeTest ( BuildFcDequantConcatProgramDesc ( use_mkldnn , scale_out , scale ) ,
remove_nodes ) ;
IsForceFp32OutputTest (
BuildFcDequantConcatProgramDesc ( use_mkldnn , scale_out , scale ) , " fc " ,
true ) ;
}
// If there are more than one op after fc->dequantize, do not fuse
TEST ( CpuQuantizeSquashPass , fc_dequant_more_than_one_op_after_dequant ) {
auto scale_out = 1.0f ;
auto scale = 1.2345f ;
auto use_mkldnn = true ;
// nothing change
auto remove_nodes = 0 ;
CountNodeTest ( BuildFcDequantFcProgramDesc ( use_mkldnn , scale_out , scale ) ,
remove_nodes ) ;
IsForceFp32OutputTest (
BuildFcDequantFcProgramDesc ( use_mkldnn , scale_out , scale ) , " fc " , false ) ;
}
} // namespace ir