|
|
|
@ -1,10 +1,12 @@
|
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
|
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
|
|
|
|
|
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
|
|
|
|
|
Operator
|
|
|
|
|
from paddle.v2.framework.initializer import ConstantInitializer, \
|
|
|
|
|
NormalInitializer
|
|
|
|
|
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
|
|
|
|
|
import re
|
|
|
|
|
import cStringIO
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
|
|
|
|
@ -240,6 +242,58 @@ def _convert_(name):
|
|
|
|
|
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_doc_string_(op_proto):
|
|
|
|
|
"""
|
|
|
|
|
Generate docstring by OpProto
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
str: the document string
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _type_to_str_(tp):
|
|
|
|
|
return framework_pb2.AttrType.Name(tp)
|
|
|
|
|
|
|
|
|
|
if not isinstance(op_proto, framework_pb2.OpProto):
|
|
|
|
|
raise TypeError("OpProto should be `framework_pb2.OpProto`")
|
|
|
|
|
|
|
|
|
|
buf = cStringIO.StringIO()
|
|
|
|
|
buf.write(op_proto.comment)
|
|
|
|
|
buf.write('\nArgs:\n')
|
|
|
|
|
for each_input in op_proto.inputs:
|
|
|
|
|
line_begin = ' {0}: '.format(_convert_(each_input.name))
|
|
|
|
|
buf.write(line_begin)
|
|
|
|
|
buf.write(each_input.comment)
|
|
|
|
|
buf.write('\n')
|
|
|
|
|
buf.write(' ' * len(line_begin))
|
|
|
|
|
buf.write('Duplicable: ')
|
|
|
|
|
buf.write(str(each_input.duplicable))
|
|
|
|
|
buf.write(' Optional: ')
|
|
|
|
|
buf.write(str(each_input.dispensable))
|
|
|
|
|
buf.write('\n')
|
|
|
|
|
|
|
|
|
|
for each_attr in op_proto.attrs:
|
|
|
|
|
buf.write(' ')
|
|
|
|
|
buf.write(each_attr.name)
|
|
|
|
|
buf.write(' (')
|
|
|
|
|
buf.write(_type_to_str_(each_attr.type))
|
|
|
|
|
buf.write('): ')
|
|
|
|
|
buf.write(each_attr.comment)
|
|
|
|
|
buf.write('\n')
|
|
|
|
|
|
|
|
|
|
if len(op_proto.outputs) != 0:
|
|
|
|
|
buf.write('\nReturns:\n')
|
|
|
|
|
buf.write(' ')
|
|
|
|
|
for each_opt in op_proto.outputs:
|
|
|
|
|
if not each_opt.intermediate:
|
|
|
|
|
break
|
|
|
|
|
buf.write(each_opt.comment)
|
|
|
|
|
|
|
|
|
|
return buf.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_op_func_(op_type):
|
|
|
|
|
"""
|
|
|
|
|
Create an Operator for a Function.
|
|
|
|
@ -298,11 +352,6 @@ def _create_op_func_(op_type):
|
|
|
|
|
return dtype
|
|
|
|
|
|
|
|
|
|
def func(**kwargs):
|
|
|
|
|
"""
|
|
|
|
|
This function implements the function for the operator. This process
|
|
|
|
|
involves doing the sanity check (using the function above), reading
|
|
|
|
|
inputs from protobuf and applying the activations on top.
|
|
|
|
|
"""
|
|
|
|
|
helper = LayerHelper(op_type, **kwargs)
|
|
|
|
|
|
|
|
|
|
dtype = infer_and_check_data_type(op_proto, **kwargs)
|
|
|
|
@ -326,6 +375,7 @@ def _create_op_func_(op_type):
|
|
|
|
|
|
|
|
|
|
func.__name__ = op_type
|
|
|
|
|
globals()[op_type] = func
|
|
|
|
|
func.__doc__ = _generate_doc_string_(op_proto)
|
|
|
|
|
global __all__
|
|
|
|
|
__all__.append(op_type)
|
|
|
|
|
|
|
|
|
|