|
|
|
@ -15,16 +15,13 @@ import re
|
|
|
|
|
import cStringIO
|
|
|
|
|
import functools
|
|
|
|
|
import warnings
|
|
|
|
|
import string
|
|
|
|
|
|
|
|
|
|
from ..proto import framework_pb2
|
|
|
|
|
from ..framework import OpProtoHolder, Variable
|
|
|
|
|
from ..layer_helper import LayerHelper
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'deprecated',
|
|
|
|
|
'generate_layer_fn',
|
|
|
|
|
'autodoc',
|
|
|
|
|
]
|
|
|
|
|
__all__ = ['deprecated', 'generate_layer_fn', 'autodoc', 'templatedoc']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_(name):
|
|
|
|
@ -43,6 +40,10 @@ def _convert_(name):
|
|
|
|
|
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _type_to_str_(tp):
|
|
|
|
|
return framework_pb2.AttrType.Name(tp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_doc_string_(op_proto):
|
|
|
|
|
"""
|
|
|
|
|
Generate docstring by OpProto
|
|
|
|
@ -54,9 +55,6 @@ def _generate_doc_string_(op_proto):
|
|
|
|
|
str: the document string
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _type_to_str_(tp):
|
|
|
|
|
return framework_pb2.AttrType.Name(tp)
|
|
|
|
|
|
|
|
|
|
if not isinstance(op_proto, framework_pb2.OpProto):
|
|
|
|
|
raise TypeError("OpProto should be `framework_pb2.OpProto`")
|
|
|
|
|
|
|
|
|
@ -220,3 +218,42 @@ def autodoc(comment=""):
|
|
|
|
|
return func
|
|
|
|
|
|
|
|
|
|
return __impl__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def templatedoc():
|
|
|
|
|
"""
|
|
|
|
|
Decorator of layer function. It will use the docstring from the layer
|
|
|
|
|
function as the template. The template arguments are:
|
|
|
|
|
|
|
|
|
|
* ${comment}: The operator comment written in CPP.
|
|
|
|
|
* ${{name}_comment}: The comment of ${name} written with AddAttr, AddOutput,
|
|
|
|
|
and AddInput. The ${name} is Python snake style. i.e., xxx_xxx.
|
|
|
|
|
* ${{name}_type}: The type of ${name}.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Decorated funciton.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __impl__(func):
|
|
|
|
|
op_proto = OpProtoHolder.instance().get_op_proto(func.__name__)
|
|
|
|
|
tmpl = string.Template(func.__doc__)
|
|
|
|
|
args = {"comment": " ".join(op_proto.comment.split())}
|
|
|
|
|
for each_input in op_proto.inputs:
|
|
|
|
|
input_name = _convert_(each_input.name)
|
|
|
|
|
args["{0}_comment".format(input_name)] = each_input.comment
|
|
|
|
|
args["{0}_type".format(input_name)] = "Variable"
|
|
|
|
|
for each_attr in op_proto.attrs:
|
|
|
|
|
input_name = _convert_(each_attr.name)
|
|
|
|
|
args["{0}_comment".format(input_name)] = each_attr.comment
|
|
|
|
|
args["{0}_type".format(input_name)] = _type_to_str_(each_attr.type)
|
|
|
|
|
|
|
|
|
|
for each_opt in op_proto.outputs:
|
|
|
|
|
output_name = _convert_(each_opt.name)
|
|
|
|
|
args["{0}_comment".format(output_name)] = each_opt.comment
|
|
|
|
|
args["{0}_type".format(output_name)] = "Variable"
|
|
|
|
|
|
|
|
|
|
func.__doc__ = tmpl.substitute(args)
|
|
|
|
|
return func
|
|
|
|
|
|
|
|
|
|
return __impl__
|
|
|
|
|