|
|
|
@ -216,38 +216,54 @@ def create_op_creation_method(op_proto):
|
|
|
|
|
opdesc = method(*args, **kwargs)
|
|
|
|
|
return core.Operator.create(opdesc.SerializeToString())
|
|
|
|
|
|
|
|
|
|
__impl__.__doc__ = get_docstring_from_op_proto(op_proto)
|
|
|
|
|
__impl__.all_input_args = [var.name for var in op_proto.inputs]
|
|
|
|
|
__impl__.all_output_args = [var.name for var in op_proto.outputs]
|
|
|
|
|
__impl__.all_attr_args = [attr.name for attr in op_proto.attrs]
|
|
|
|
|
__impl__.all_not_temp_output_args = [
|
|
|
|
|
var.name for var in op_proto.outputs if not var.temporary
|
|
|
|
|
]
|
|
|
|
|
return {
|
|
|
|
|
'method': __impl__,
|
|
|
|
|
'name': op_proto.type,
|
|
|
|
|
'all_inputs': [var.name for var in op_proto.inputs],
|
|
|
|
|
'all_outputs': [var.name for var in op_proto.outputs],
|
|
|
|
|
'all_attrs': [attr.name for attr in op_proto.attrs],
|
|
|
|
|
'all_no_temp_outputs':
|
|
|
|
|
[var.name for var in op_proto.outputs if not var.temporary]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OperatorFactory(object):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.op_methods = dict()
|
|
|
|
|
for op_proto in get_all_op_protos():
|
|
|
|
|
method = create_op_creation_method(op_proto)
|
|
|
|
|
self.op_methods[method.name] = method
|
|
|
|
|
|
|
|
|
|
return __impl__
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
|
|
if 'type' in kwargs:
|
|
|
|
|
if len(args) != 0:
|
|
|
|
|
raise ValueError("All Paddle argument should be key-word "
|
|
|
|
|
"argument except type")
|
|
|
|
|
t = kwargs.pop('type')
|
|
|
|
|
else:
|
|
|
|
|
if len(args) != 1:
|
|
|
|
|
raise ValueError("All Paddle argument should be key-word "
|
|
|
|
|
"argument except type")
|
|
|
|
|
t = args[0]
|
|
|
|
|
|
|
|
|
|
return self.get_op_creation_info(t)['method'](**kwargs)
|
|
|
|
|
|
|
|
|
|
class OpCreationsHolder(object):
|
|
|
|
|
"""
|
|
|
|
|
A object will holds all op creation methods.
|
|
|
|
|
|
|
|
|
|
Use `op_creations.xxx_op` to access them.
|
|
|
|
|
"""
|
|
|
|
|
pass
|
|
|
|
|
def get_op_creation_info(self, t):
|
|
|
|
|
if t not in self.op_methods:
|
|
|
|
|
raise ValueError("operator %s is not registered", t)
|
|
|
|
|
return self.op_methods.get(t)
|
|
|
|
|
|
|
|
|
|
def get_op_input_names(self, type):
|
|
|
|
|
return self.get_op_creation_info(type)['all_inputs']
|
|
|
|
|
|
|
|
|
|
op_creations = OpCreationsHolder()
|
|
|
|
|
def get_op_output_names(self, type):
|
|
|
|
|
return self.get_op_creation_info(type)['all_outputs']
|
|
|
|
|
|
|
|
|
|
def get_op_attr_names(self, type):
|
|
|
|
|
return self.get_op_creation_info(type)['all_attrs']
|
|
|
|
|
|
|
|
|
|
def __bootstrap__():
|
|
|
|
|
"""
|
|
|
|
|
Bootstrap function for this module. It will dynamic create all op creation
|
|
|
|
|
methods in runtime.
|
|
|
|
|
"""
|
|
|
|
|
for op_proto in get_all_op_protos():
|
|
|
|
|
func = create_op_creation_method(op_proto)
|
|
|
|
|
func.__name__ = str(op_proto.type)
|
|
|
|
|
setattr(op_creations, func.__name__, func)
|
|
|
|
|
def get_op_no_temp_output_names(self, type):
|
|
|
|
|
return self.get_op_creation_info(type)['all_no_temp_outputs']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__bootstrap__()
|
|
|
|
|
Operator = OperatorFactory() # Default global factory
|