@ -17,12 +17,14 @@ from __future__ import print_function
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					import  unittest 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					import  numpy  as  np 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  op_test  import  OpTest 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					import  paddle . fluid . core  as  core 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOp ( OpTest ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  setUp ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . init_axis ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        x  =  np . random . random ( ( 2 ,  3 ,  4 ,  5 ,  10 ) ) . astype ( " float32 " ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . init_datatype ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        x  =  np . random . random ( ( 2 ,  3 ,  4 ,  5 ,  10 ) ) . astype ( self . dtype ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  self . axis  <  0 : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            self . axis  =  self . axis  +  len ( x . shape ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . indices  =  np . argsort ( x ,  kind = ' quicksort ' ,  axis = self . axis ) 
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -35,6 +37,9 @@ class TestArgsortOp(OpTest):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_axis ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  - 1 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_datatype ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . dtype  =  " float32 " 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  test_check_output ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . check_output ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -49,10 +54,54 @@ class TestArgsortOpAxis1(TestArgsortOp):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  1 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOpAxis2 ( TestArgsortOp ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_axis ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  2 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOpAxisNeg1 ( TestArgsortOp ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_axis ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  - 1 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOpAxisNeg2 ( TestArgsortOp ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_axis ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  - 2 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOpFP16 ( TestArgsortOp ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_datatype ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  core . is_compiled_with_cuda ( ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            self . dtype  =  ' float16 ' 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  test_check_output ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        pass 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  test_check_output_with_place ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  core . is_compiled_with_cuda ( ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            place  =  core . CUDAPlace ( 0 ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            self . check_output_with_place ( place ,  atol = 1e-5 ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOpFP16Axis0 ( TestArgsortOpFP16 ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_axis ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  0 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOpFP16Axis2 ( TestArgsortOpFP16 ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_axis ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  2 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOpFP16AxisNeg2 ( TestArgsortOpFP16 ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_axis ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  - 2 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  TestArgsortOpFP16Axis4Neg4 ( TestArgsortOpFP16 ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  init_axis ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . axis  =  - 4 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					if  __name__  ==  " __main__ " : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    unittest . main ( )