|
|
|
@ -44,6 +44,11 @@ def _type_to_str_(tp):
|
|
|
|
|
return framework_pb2.AttrType.Name(tp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
|
|
|
|
|
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
|
|
|
|
|
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_doc_string_(op_proto):
|
|
|
|
|
"""
|
|
|
|
|
Generate docstring by OpProto
|
|
|
|
@ -55,22 +60,27 @@ def _generate_doc_string_(op_proto):
|
|
|
|
|
str: the document string
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def escape_math(text):
|
|
|
|
|
return _two_bang_pattern_.sub(
|
|
|
|
|
r'$$\1$$',
|
|
|
|
|
_single_dollar_pattern_.sub(
|
|
|
|
|
r':math:`\1`', _two_dollar_pattern_.sub(r"!!\1!!", text)))
|
|
|
|
|
|
|
|
|
|
if not isinstance(op_proto, framework_pb2.OpProto):
|
|
|
|
|
raise TypeError("OpProto should be `framework_pb2.OpProto`")
|
|
|
|
|
|
|
|
|
|
buf = cStringIO.StringIO()
|
|
|
|
|
buf.write(op_proto.comment)
|
|
|
|
|
buf.write(escape_math(op_proto.comment))
|
|
|
|
|
buf.write('\nArgs:\n')
|
|
|
|
|
for each_input in op_proto.inputs:
|
|
|
|
|
line_begin = ' {0}: '.format(_convert_(each_input.name))
|
|
|
|
|
buf.write(line_begin)
|
|
|
|
|
buf.write(each_input.comment)
|
|
|
|
|
buf.write(escape_math(each_input.comment))
|
|
|
|
|
buf.write('\n')
|
|
|
|
|
buf.write(' ' * len(line_begin))
|
|
|
|
|
buf.write('Duplicable: ')
|
|
|
|
|
buf.write(str(each_input.duplicable))
|
|
|
|
|
buf.write(' Optional: ')
|
|
|
|
|
buf.write(str(each_input.dispensable))
|
|
|
|
|
if each_input.duplicable:
|
|
|
|
|
buf.write(" Duplicatable.")
|
|
|
|
|
if each_input.dispensable:
|
|
|
|
|
buf.write(" Optional.")
|
|
|
|
|
buf.write('\n')
|
|
|
|
|
|
|
|
|
|
skip_attrs = OpProtoHolder.generated_op_attr_names()
|
|
|
|
@ -83,7 +93,7 @@ def _generate_doc_string_(op_proto):
|
|
|
|
|
buf.write(' (')
|
|
|
|
|
buf.write(_type_to_str_(each_attr.type))
|
|
|
|
|
buf.write('): ')
|
|
|
|
|
buf.write(each_attr.comment)
|
|
|
|
|
buf.write(escape_math(each_attr.comment))
|
|
|
|
|
buf.write('\n')
|
|
|
|
|
|
|
|
|
|
if len(op_proto.outputs) != 0:
|
|
|
|
@ -92,7 +102,7 @@ def _generate_doc_string_(op_proto):
|
|
|
|
|
for each_opt in op_proto.outputs:
|
|
|
|
|
if not each_opt.intermediate:
|
|
|
|
|
break
|
|
|
|
|
buf.write(each_opt.comment)
|
|
|
|
|
buf.write(escape_math(each_opt.comment))
|
|
|
|
|
|
|
|
|
|
return buf.getvalue()
|
|
|
|
|
|
|
|
|
|