|  |  | @ -63,6 +63,7 @@ class TestSpliteSelectedRows(unittest.TestCase): | 
			
		
	
		
		
			
				
					
					|  |  |  |         # expected output selected rows |  |  |  |         # expected output selected rows | 
			
		
	
		
		
			
				
					
					|  |  |  |         expected_out0_rows = [0, 4] |  |  |  |         expected_out0_rows = [0, 4] | 
			
		
	
		
		
			
				
					
					|  |  |  |         expected_out1_rows = [0, 2] |  |  |  |         expected_out1_rows = [0, 2] | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         expected_out2_rows = [] | 
			
		
	
		
		
			
				
					
					|  |  |  |         expected_out4_rows = [0] |  |  |  |         expected_out4_rows = [0] | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         op = Operator( |  |  |  |         op = Operator( | 
			
		
	
	
		
		
			
				
					|  |  | @ -75,6 +76,7 @@ class TestSpliteSelectedRows(unittest.TestCase): | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.assertEqual(outs[0].rows(), expected_out0_rows) |  |  |  |         self.assertEqual(outs[0].rows(), expected_out0_rows) | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.assertEqual(outs[1].rows(), expected_out1_rows) |  |  |  |         self.assertEqual(outs[1].rows(), expected_out1_rows) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.assertEqual(outs[2].rows(), expected_out2_rows) | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.assertEqual(outs[4].rows(), expected_out4_rows) |  |  |  |         self.assertEqual(outs[4].rows(), expected_out4_rows) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.assertEqual(outs[0].height(), height_sections[0]) |  |  |  |         self.assertEqual(outs[0].height(), height_sections[0]) | 
			
		
	
	
		
		
			
				
					|  |  | @ -84,6 +86,9 @@ class TestSpliteSelectedRows(unittest.TestCase): | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 1]) |  |  |  |         self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 1]) | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.assertAlmostEqual(8.0, np.array(outs[4].get_tensor())[0, 1]) |  |  |  |         self.assertAlmostEqual(8.0, np.array(outs[4].get_tensor())[0, 1]) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.assertEqual(outs[2].numel(), 0) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.assertEqual(outs[3].numel(), 0) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     def check_grad_with_place(self, place): |  |  |  |     def check_grad_with_place(self, place): | 
			
		
	
		
		
			
				
					
					|  |  |  |         scope = core.Scope() |  |  |  |         scope = core.Scope() | 
			
		
	
		
		
			
				
					
					|  |  |  |         height = 10 |  |  |  |         height = 10 | 
			
		
	
	
		
		
			
				
					|  |  | 
 |