|
|
|
|
@ -402,12 +402,9 @@ def parse_op_info(op_name):
|
|
|
|
|
op_proto = OpProtoHolder.instance().get_op_proto(op_name)
|
|
|
|
|
|
|
|
|
|
in_names = [x.name for x in op_proto.inputs]
|
|
|
|
|
assert len(op_proto.outputs) == 1
|
|
|
|
|
out_name = op_proto.outputs[0].name
|
|
|
|
|
out_names = [x.name for x in op_proto.outputs]
|
|
|
|
|
|
|
|
|
|
# TODO(Aurelius84): parse necessary out_dtype of custom op
|
|
|
|
|
out_infos = {out_name: ['float32']}
|
|
|
|
|
return in_names, out_infos
|
|
|
|
|
return in_names, out_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _import_module_from_library(module_name, build_directory, verbose=False):
|
|
|
|
|
@ -450,13 +447,10 @@ def _generate_python_module(module_name,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _custom_api_content(op_name):
|
|
|
|
|
params_str, ins_str = _get_api_inputs_str(op_name)
|
|
|
|
|
params_str, ins_str, outs_str = _get_api_inputs_str(op_name)
|
|
|
|
|
|
|
|
|
|
API_TEMPLATE = textwrap.dedent("""
|
|
|
|
|
from paddle.fluid.layer_helper import LayerHelper
|
|
|
|
|
from paddle.utils.cpp_extension import parse_op_info
|
|
|
|
|
|
|
|
|
|
_, _out_infos = parse_op_info('{op_name}')
|
|
|
|
|
|
|
|
|
|
def {op_name}({inputs}):
|
|
|
|
|
helper = LayerHelper("{op_name}", **locals())
|
|
|
|
|
@ -464,21 +458,22 @@ def _custom_api_content(op_name):
|
|
|
|
|
# prepare inputs and output
|
|
|
|
|
ins = {ins}
|
|
|
|
|
outs = {{}}
|
|
|
|
|
for out_name in _out_infos:
|
|
|
|
|
outs[out_name] = [helper.create_variable(dtype=dtype) for dtype in _out_infos[out_name]]
|
|
|
|
|
out_names = {out_names}
|
|
|
|
|
for out_name in out_names:
|
|
|
|
|
# Set 'float32' temporarily, and the actual dtype of output variable will be inferred
|
|
|
|
|
# in runtime.
|
|
|
|
|
outs[out_name] = helper.create_variable(dtype='float32')
|
|
|
|
|
|
|
|
|
|
helper.append_op(type="{op_name}", inputs=ins, outputs=outs)
|
|
|
|
|
|
|
|
|
|
res = list(outs.values())[0]
|
|
|
|
|
if len(res) == 1:
|
|
|
|
|
return res[0]
|
|
|
|
|
else:
|
|
|
|
|
return res
|
|
|
|
|
res = [outs[out_name] for out_name in out_names]
|
|
|
|
|
|
|
|
|
|
return res[0] if len(res)==1 else res
|
|
|
|
|
""").lstrip()
|
|
|
|
|
|
|
|
|
|
# generate python api file
|
|
|
|
|
api_content = API_TEMPLATE.format(
|
|
|
|
|
op_name=op_name, inputs=params_str, ins=ins_str)
|
|
|
|
|
op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str)
|
|
|
|
|
|
|
|
|
|
return api_content
|
|
|
|
|
|
|
|
|
|
@ -509,13 +504,15 @@ def _get_api_inputs_str(op_name):
|
|
|
|
|
"""
|
|
|
|
|
Returns string of api parameters and inputs dict.
|
|
|
|
|
"""
|
|
|
|
|
in_names, _ = parse_op_info(op_name)
|
|
|
|
|
in_names, out_names = parse_op_info(op_name)
|
|
|
|
|
# e.g: x, y, z
|
|
|
|
|
params_str = ','.join([p.lower() for p in in_names])
|
|
|
|
|
# e.g: {'X': x, 'Y': y, 'Z': z}
|
|
|
|
|
ins_str = "{%s}" % ','.join(
|
|
|
|
|
["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names])
|
|
|
|
|
return params_str, ins_str
|
|
|
|
|
# e.g: ['Out', 'Index']
|
|
|
|
|
outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names])
|
|
|
|
|
return params_str, ins_str, outs_str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _write_setup_file(name,
|
|
|
|
|
|