|  |  |  | @ -402,12 +402,9 @@ def parse_op_info(op_name): | 
			
		
	
		
			
				
					|  |  |  |  |     op_proto = OpProtoHolder.instance().get_op_proto(op_name) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     in_names = [x.name for x in op_proto.inputs] | 
			
		
	
		
			
				
					|  |  |  |  |     assert len(op_proto.outputs) == 1 | 
			
		
	
		
			
				
					|  |  |  |  |     out_name = op_proto.outputs[0].name | 
			
		
	
		
			
				
					|  |  |  |  |     out_names = [x.name for x in op_proto.outputs] | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     # TODO(Aurelius84): parse necessary out_dtype  of custom op | 
			
		
	
		
			
				
					|  |  |  |  |     out_infos = {out_name: ['float32']} | 
			
		
	
		
			
				
					|  |  |  |  |     return in_names, out_infos | 
			
		
	
		
			
				
					|  |  |  |  |     return in_names, out_names | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | def _import_module_from_library(module_name, build_directory, verbose=False): | 
			
		
	
	
		
			
				
					|  |  |  | @ -450,13 +447,10 @@ def _generate_python_module(module_name, | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | def _custom_api_content(op_name): | 
			
		
	
		
			
				
					|  |  |  |  |     params_str, ins_str = _get_api_inputs_str(op_name) | 
			
		
	
		
			
				
					|  |  |  |  |     params_str, ins_str, outs_str = _get_api_inputs_str(op_name) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     API_TEMPLATE = textwrap.dedent(""" | 
			
		
	
		
			
				
					|  |  |  |  |         from paddle.fluid.layer_helper import LayerHelper | 
			
		
	
		
			
				
					|  |  |  |  |         from paddle.utils.cpp_extension import parse_op_info | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |         _, _out_infos = parse_op_info('{op_name}') | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |         def {op_name}({inputs}): | 
			
		
	
		
			
				
					|  |  |  |  |             helper = LayerHelper("{op_name}", **locals()) | 
			
		
	
	
		
			
				
					|  |  |  | @ -464,21 +458,22 @@ def _custom_api_content(op_name): | 
			
		
	
		
			
				
					|  |  |  |  |             # prepare inputs and output  | 
			
		
	
		
			
				
					|  |  |  |  |             ins = {ins} | 
			
		
	
		
			
				
					|  |  |  |  |             outs = {{}} | 
			
		
	
		
			
				
					|  |  |  |  |             for out_name in _out_infos: | 
			
		
	
		
			
				
					|  |  |  |  |                 outs[out_name] = [helper.create_variable(dtype=dtype) for dtype in _out_infos[out_name]] | 
			
		
	
		
			
				
					|  |  |  |  |             out_names = {out_names} | 
			
		
	
		
			
				
					|  |  |  |  |             for out_name in out_names: | 
			
		
	
		
			
				
					|  |  |  |  |                 # Set 'float32' temporarily, and the actual dtype of output variable will be inferred | 
			
		
	
		
			
				
					|  |  |  |  |                 # in runtime. | 
			
		
	
		
			
				
					|  |  |  |  |                 outs[out_name] = helper.create_variable(dtype='float32') | 
			
		
	
		
			
				
					|  |  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |  |             helper.append_op(type="{op_name}", inputs=ins, outputs=outs) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |             res = list(outs.values())[0] | 
			
		
	
		
			
				
					|  |  |  |  |             if len(res) == 1: | 
			
		
	
		
			
				
					|  |  |  |  |                 return res[0] | 
			
		
	
		
			
				
					|  |  |  |  |             else: | 
			
		
	
		
			
				
					|  |  |  |  |                 return res | 
			
		
	
		
			
				
					|  |  |  |  |             res = [outs[out_name] for out_name in out_names] | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |             return res[0] if len(res)==1 else res | 
			
		
	
		
			
				
					|  |  |  |  |             """).lstrip() | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     # generate python api file | 
			
		
	
		
			
				
					|  |  |  |  |     api_content = API_TEMPLATE.format( | 
			
		
	
		
			
				
					|  |  |  |  |         op_name=op_name, inputs=params_str, ins=ins_str) | 
			
		
	
		
			
				
					|  |  |  |  |         op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     return api_content | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -509,13 +504,15 @@ def _get_api_inputs_str(op_name): | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  |     Returns string of api parameters and inputs dict. | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  |     in_names, _ = parse_op_info(op_name) | 
			
		
	
		
			
				
					|  |  |  |  |     in_names, out_names = parse_op_info(op_name) | 
			
		
	
		
			
				
					|  |  |  |  |     # e.g: x, y, z | 
			
		
	
		
			
				
					|  |  |  |  |     params_str = ','.join([p.lower() for p in in_names]) | 
			
		
	
		
			
				
					|  |  |  |  |     # e.g: {'X': x, 'Y': y, 'Z': z} | 
			
		
	
		
			
				
					|  |  |  |  |     ins_str = "{%s}" % ','.join( | 
			
		
	
		
			
				
					|  |  |  |  |         ["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names]) | 
			
		
	
		
			
				
					|  |  |  |  |     return params_str, ins_str | 
			
		
	
		
			
				
					|  |  |  |  |     # e.g: ['Out', 'Index'] | 
			
		
	
		
			
				
					|  |  |  |  |     outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names]) | 
			
		
	
		
			
				
					|  |  |  |  |     return params_str, ins_str, outs_str | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | def _write_setup_file(name, | 
			
		
	
	
		
			
				
					|  |  |  | 
 |