| 
						
						
							
								
							
						
						
					 | 
					 | 
					@ -758,6 +758,7 @@ class QuantizationTransformPass(object):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            attrs={
 | 
					 | 
					 | 
					 | 
					            attrs={
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                'bit_length': quant_bits,
 | 
					 | 
					 | 
					 | 
					                'bit_length': quant_bits,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                'quant_axis': quant_axis,
 | 
					 | 
					 | 
					 | 
					                'quant_axis': quant_axis,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                'is_test': self._is_test,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
 | 
					 | 
					 | 
					 | 
					                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            },
 | 
					 | 
					 | 
					 | 
					            },
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            inputs={'X': var_node},
 | 
					 | 
					 | 
					 | 
					            inputs={'X': var_node},
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					 | 
					@ -1125,7 +1126,7 @@ class QuantizationFreezePass(object):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    self._restore_var(input_arg_name, quantized_param_v)
 | 
					 | 
					 | 
					 | 
					                    self._restore_var(input_arg_name, quantized_param_v)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    self._remove_fake_quant_and_dequant_op(graph, op_node)
 | 
					 | 
					 | 
					 | 
					                    self._remove_fake_quant_and_dequant_op(graph, op_node)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					# Remove all fake dequant op
 | 
					 | 
					 | 
					 | 
					        # Remove all fake dequant op
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        ops = graph.all_op_nodes()
 | 
					 | 
					 | 
					 | 
					        ops = graph.all_op_nodes()
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        for op_node in ops:
 | 
					 | 
					 | 
					 | 
					        for op_node in ops:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            op_name = op_node.name()
 | 
					 | 
					 | 
					 | 
					            op_name = op_node.name()
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					 | 
					@ -1331,16 +1332,25 @@ class QuantizationFreezePass(object):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def _quant(self, x, scale, num_bits, quant_axis):
 | 
					 | 
					 | 
					 | 
					    def _quant(self, x, scale, num_bits, quant_axis):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
 | 
					 | 
					 | 
					 | 
					        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					        bnt = (1 << (num_bits - 1)) - 1
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					        def _clip(x, scale):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					            x[x > scale] = scale
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					            x[x < -scale] = -scale
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					            return x
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        if isinstance(scale, list):
 | 
					 | 
					 | 
					 | 
					        if isinstance(scale, list):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            for i, s in enumerate(scale):
 | 
					 | 
					 | 
					 | 
					            for i, s in enumerate(scale):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                if quant_axis == 0:
 | 
					 | 
					 | 
					 | 
					                if quant_axis == 0:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
 | 
					 | 
					 | 
					 | 
					                    x[i] = _clip(x[i], s)
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                    x[i] = np.round(x[i] / s * bnt)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                else:
 | 
					 | 
					 | 
					 | 
					                else:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    x[:, i] = np.round(x[:, i] / s * (
 | 
					 | 
					 | 
					 | 
					                    x[:, i] = _clip(x[:, i], s)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        (1 << (num_bits - 1)) - 1))
 | 
					 | 
					 | 
					 | 
					                    x[:, i] = np.round(x[:, i] / s * bnt)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            return x
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        else:
 | 
					 | 
					 | 
					 | 
					        else:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
 | 
					 | 
					 | 
					 | 
					            x = _clip(x, scale)
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					            x = np.round(x / scale * bnt)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					        return x
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					class ConvertToInt8Pass(object):
 | 
					 | 
					 | 
					 | 
					class ConvertToInt8Pass(object):
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					 | 
					
 
 |