@ -17,6 +17,7 @@ limitations under the License. */
# include <gtest/gtest.h>
# include <set>
# include <string>
# include <vector>
# include "paddle/fluid/framework/attribute.h"
# include "paddle/fluid/framework/operator.h"
@ -61,12 +62,12 @@ TEST(Prune, one_operator) {
f : : proto : : ProgramDesc pruned ;
std : : set < std : : string > feed_var_names = { } ;
f : : Prune ( * pdesc , feed_var_names , & pruned ) ;
PADDLE_ENFORCE _EQ( pruned . blocks ( 0 ) . ops_size ( ) , 0 ) ;
EXPECT _EQ( pruned . blocks ( 0 ) . ops_size ( ) , 0 ) ;
feed_var_names . insert ( " a " ) ;
pdesc - > mutable_blocks ( 0 ) - > mutable_ops ( 0 ) - > set_is_target ( true ) ;
f : : Prune ( * pdesc , feed_var_names , & pruned ) ;
PADDLE_ENFORCE _EQ( pruned . blocks ( 0 ) . ops_size ( ) , 1 ) ;
EXPECT _EQ( pruned . blocks ( 0 ) . ops_size ( ) , 1 ) ;
}
TEST ( Prune , forward ) {
@ -88,7 +89,7 @@ TEST(Prune, forward) {
f : : proto : : ProgramDesc pruned ;
pdesc - > mutable_blocks ( 0 ) - > mutable_ops ( i ) - > set_is_target ( true ) ;
f : : Prune ( * pdesc , feed_var_names , & pruned ) ;
PADDLE_ENFORCE _EQ( pruned . blocks ( 0 ) . ops_size ( ) , i + 1 ) ;
EXPECT _EQ( pruned . blocks ( 0 ) . ops_size ( ) , i + 1 ) ;
}
}
@ -111,7 +112,7 @@ TEST(Prune, multi_input_op) {
f : : proto : : ProgramDesc pruned ;
std : : set < std : : string > feed_var_names = { " a0 " , " a1 " , " a2 " } ;
f : : Prune ( * pdesc , feed_var_names , & pruned ) ;
PADDLE_ENFORCE _EQ( pruned . blocks ( 0 ) . ops_size ( ) , 4 ) ;
EXPECT _EQ( pruned . blocks ( 0 ) . ops_size ( ) , 4 ) ;
}
TEST ( Prune , multi_output_op ) {
@ -131,7 +132,7 @@ TEST(Prune, multi_output_op) {
f : : proto : : ProgramDesc pruned ;
std : : set < std : : string > feed_var_names = { " a " } ;
f : : Prune ( * pdesc , feed_var_names , & pruned ) ;
PADDLE_ENFORCE _EQ( pruned . blocks ( 0 ) . ops_size ( ) , 2 ) ;
EXPECT _EQ( pruned . blocks ( 0 ) . ops_size ( ) , 2 ) ;
}
TEST ( Prune , multi_target ) {
@ -152,5 +153,35 @@ TEST(Prune, multi_target) {
f : : proto : : ProgramDesc pruned ;
std : : set < std : : string > feed_var_names = { " a " } ;
f : : Prune ( * pdesc , feed_var_names , & pruned ) ;
PADDLE_ENFORCE_EQ ( pruned . blocks ( 0 ) . ops_size ( ) , 3 ) ;
EXPECT_EQ ( pruned . blocks ( 0 ) . ops_size ( ) , 3 ) ;
}
TEST ( Prune , recurrrent_op ) {
f : : ProgramDesc program ;
f : : BlockDesc * block = program . MutableBlock ( 0 ) ;
f : : BlockDesc * sub_block = program . AppendBlock ( * block ) ;
AddOp ( " one_two " , { { " input " , { " a " } } } , { { " output " , { " b " , " c " } } } ,
f : : AttributeMap { } , block ) ;
std : : vector < std : : string > state_var_name ( 1 , " y " ) ;
AddOp ( " recurrent " , { { " input " , { " b " , " c " } } } , { { " output " , { " b1, c1 " } } } ,
{ { " ex_states " , state_var_name } ,
{ " states " , state_var_name } ,
{ " sub_block " , sub_block } } ,
block ) ;
EXPECT_TRUE ( sub_block ! = nullptr ) ;
AddOp ( " rnn_memory_helper " , { { " input " , { " x " } } } , { { " output " , { " y " } } } ,
f : : AttributeMap { } , sub_block ) ;
f : : proto : : ProgramDesc * pdesc = program . Proto ( ) ;
pdesc - > mutable_blocks ( 0 ) - > mutable_ops ( 1 ) - > set_is_target ( true ) ;
f : : proto : : ProgramDesc pruned ;
std : : set < std : : string > feed_var_names = { " a " } ;
f : : Prune ( * pdesc , feed_var_names , & pruned ) ;
EXPECT_EQ ( pruned . blocks_size ( ) , 2 ) ;
EXPECT_EQ ( pruned . blocks ( 0 ) . ops_size ( ) , 2 ) ;
EXPECT_EQ ( pruned . blocks ( 1 ) . ops_size ( ) , 1 ) ;
}