|
|
|
@ -23,7 +23,7 @@ from paddle.fluid.proto import framework_pb2
|
|
|
|
|
from paddle.fluid.framework import OpProtoHolder, Variable
|
|
|
|
|
from paddle.fluid.layer_helper import LayerHelper
|
|
|
|
|
|
|
|
|
|
g_filer_attrs = ['op_role', 'op_role_var', 'op_namescope', 'dtype']
|
|
|
|
|
g_filer_attrs = ['op_role', 'op_role_var', 'op_namescope']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_(name):
|
|
|
|
@ -46,7 +46,7 @@ def _get_inputs(op_type):
|
|
|
|
|
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
|
|
|
|
|
inputs = dict()
|
|
|
|
|
for ipt in op_proto.inputs:
|
|
|
|
|
inputs[ipt.name] = ""
|
|
|
|
|
inputs[ipt.name] = ipt.comment
|
|
|
|
|
|
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
@ -60,6 +60,34 @@ def _get_outputs(op_type):
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
|
|
|
|
|
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
|
|
|
|
|
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def escape_math(text):
|
|
|
|
|
return _two_bang_pattern_.sub(
|
|
|
|
|
r'$$\1$$',
|
|
|
|
|
_single_dollar_pattern_.sub(r':math:`\1`',
|
|
|
|
|
_two_dollar_pattern_.sub(r"!!\1!!", text)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_comment(op_type):
|
|
|
|
|
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
|
|
|
|
|
|
|
|
|
|
comment_lines = op_proto.comment.split("\n")
|
|
|
|
|
comment = ""
|
|
|
|
|
for line in comment_lines:
|
|
|
|
|
line = line.strip()
|
|
|
|
|
if len(line) != 0:
|
|
|
|
|
comment += escape_math(line)
|
|
|
|
|
comment += " "
|
|
|
|
|
elif len(comment) != 0:
|
|
|
|
|
comment += "\n "
|
|
|
|
|
|
|
|
|
|
return comment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_attrs(op_type):
|
|
|
|
|
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
|
|
|
|
|
return op_proto.attrs
|
|
|
|
@ -77,14 +105,14 @@ def get_input_comments(op_type, indent=2):
|
|
|
|
|
ret = ""
|
|
|
|
|
inputs = _get_inputs(op_type)
|
|
|
|
|
for t in inputs:
|
|
|
|
|
ret += get_indent_space(2) + "input(${%s_type}): ${%s_comment}\n" % (
|
|
|
|
|
_convert_(t), _convert_(t))
|
|
|
|
|
ret += get_indent_space(2) + "%s (Type): %s\n" % (_convert_(t),
|
|
|
|
|
inputs[t])
|
|
|
|
|
|
|
|
|
|
for t in _get_attrs(op_type):
|
|
|
|
|
if t.name in g_filer_attrs:
|
|
|
|
|
continue
|
|
|
|
|
ret += get_indent_space(2) + "input(${%s_type}): ${%s_comment}\n" % (
|
|
|
|
|
_convert_(t.name), _convert_(t.name))
|
|
|
|
|
ret += get_indent_space(2) + "%s (%s): %s\n" % (
|
|
|
|
|
_convert_(t.name), t.type, _convert_(t.comment))
|
|
|
|
|
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
@ -122,7 +150,7 @@ def get_inputs(op_type):
|
|
|
|
|
ret = "inputs={"
|
|
|
|
|
inputs = _get_inputs(op_type)
|
|
|
|
|
for t in inputs:
|
|
|
|
|
ret += "{}={},".format(t, _convert_(t))
|
|
|
|
|
ret += "'{}': {},".format(t, _convert_(t))
|
|
|
|
|
ret = ret.strip(",")
|
|
|
|
|
ret += "}"
|
|
|
|
|
|
|
|
|
@ -132,39 +160,11 @@ def get_inputs(op_type):
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
def get_input_dtype(op_type):
|
|
|
|
|
dtype = None
|
|
|
|
|
for ipt in _get_inputs():
|
|
|
|
|
name = _convert_(ipt.name)
|
|
|
|
|
val = kwargs.pop(name, [])
|
|
|
|
|
if not isinstance(val, list) and not isinstance(val, tuple):
|
|
|
|
|
val = [val]
|
|
|
|
|
if len(val) == 0:
|
|
|
|
|
val = [args[0]]
|
|
|
|
|
args = args[1:]
|
|
|
|
|
|
|
|
|
|
for each in val:
|
|
|
|
|
if not isinstance(each, Variable):
|
|
|
|
|
raise ValueError("input of {0} must be variable".format(
|
|
|
|
|
op_type))
|
|
|
|
|
|
|
|
|
|
if dtype is None:
|
|
|
|
|
dtype = each.dtype
|
|
|
|
|
elif dtype != each.dtype:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"operator {0} must input same dtype. {1} vs {2}".format(
|
|
|
|
|
op_type, dtype, each.dtype))
|
|
|
|
|
|
|
|
|
|
return dtype
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_outputs(op_type):
|
|
|
|
|
ret = "outputs={"
|
|
|
|
|
inputs = _get_outputs(op_type)
|
|
|
|
|
for t in inputs:
|
|
|
|
|
ret += "{}={},".format(t, _convert_(t))
|
|
|
|
|
ret += "'{}': {},".format(t, _convert_(t))
|
|
|
|
|
ret = ret.strip(",")
|
|
|
|
|
ret += "}"
|
|
|
|
|
|
|
|
|
@ -174,44 +174,13 @@ def get_outputs(op_type):
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
attr_names = sorted(op.attr_names)
|
|
|
|
|
attrs_str = ""
|
|
|
|
|
for i in range(0, len(attr_names)):
|
|
|
|
|
name = attr_names[i]
|
|
|
|
|
|
|
|
|
|
attr_type = op.desc.attr_type(name)
|
|
|
|
|
if attr_type == core.AttrType.BLOCK:
|
|
|
|
|
a = "{name} = block[{value}]".format(
|
|
|
|
|
name=name, type=attr_type, value=op.block_attr_id(name))
|
|
|
|
|
attrs_str += a
|
|
|
|
|
if i != len(attr_names) - 1:
|
|
|
|
|
attrs_str += ", "
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if attr_type == core.AttrType.BLOCKS:
|
|
|
|
|
a = "{name} = blocks{value}".format(
|
|
|
|
|
name=name, type=attr_type, value=op.blocks_attr_ids(name))
|
|
|
|
|
attrs_str += a
|
|
|
|
|
if i != len(attr_names) - 1:
|
|
|
|
|
attrs_str += ", "
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
a = "{name} = {value}".format(
|
|
|
|
|
name=name, type=attr_type, value=op.desc.attr(name))
|
|
|
|
|
attrs_str += a
|
|
|
|
|
if i != len(attr_names) - 1:
|
|
|
|
|
attrs_str += ", "
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_attrs(op_type):
|
|
|
|
|
ret = "attrs={"
|
|
|
|
|
for t in _get_attrs(op_type):
|
|
|
|
|
if t.name in g_filer_attrs:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
ret += "%s=%s," % (t.name, _convert_(t.name))
|
|
|
|
|
ret += "'%s': %s," % (t.name, _convert_(t.name))
|
|
|
|
|
|
|
|
|
|
ret = ret.strip(",")
|
|
|
|
|
ret += "}"
|
|
|
|
@ -220,12 +189,13 @@ def get_attrs(op_type):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_outvars(op_type, indent=1):
|
|
|
|
|
inputs = _get_inputs(op_type)
|
|
|
|
|
ret = ""
|
|
|
|
|
for t in _get_outputs(op_type):
|
|
|
|
|
ret += get_indent_space(
|
|
|
|
|
indent
|
|
|
|
|
) + "%s = helper.create_tmp_variable(dtype=helper.input_dtype())\n" % (
|
|
|
|
|
_convert_(t))
|
|
|
|
|
) + "%s = helper.create_tmp_variable(dtype=helper.input_dtype('%s'))\n" % (
|
|
|
|
|
(_convert_(t), list(inputs)[0]))
|
|
|
|
|
ret = ret.strip('\n')
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
@ -238,17 +208,15 @@ def get_op_py(op_type):
|
|
|
|
|
outputs = get_outputs(op_type)
|
|
|
|
|
attrs = get_attrs(op_type)
|
|
|
|
|
out_vars = get_outvars(op_type)
|
|
|
|
|
comment = get_comment(op_type)
|
|
|
|
|
|
|
|
|
|
code = """
|
|
|
|
|
@templatedoc()
|
|
|
|
|
def {op_type}({args}):
|
|
|
|
|
\"\"\"
|
|
|
|
|
{op_type}
|
|
|
|
|
|
|
|
|
|
{comment}
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
{input_comments}
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
{output_comments}
|
|
|
|
|
\"\"\"
|
|
|
|
@ -263,7 +231,7 @@ def {op_type}({args}):
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
""".format(
|
|
|
|
|
comment="${comment}",
|
|
|
|
|
comment=comment,
|
|
|
|
|
input_comments=input_comments.strip('\n'),
|
|
|
|
|
output_comments=output_comments,
|
|
|
|
|
args=args,
|
|
|
|
|