|
|
|
@ -145,6 +145,16 @@ class OpDescCreationMethod(object):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpInfo(object):
|
|
|
|
|
def __init__(self, name, method, inputs, outputs, attrs, no_temp_outputs):
|
|
|
|
|
self.name = name
|
|
|
|
|
self.method = method
|
|
|
|
|
self.inputs = inputs
|
|
|
|
|
self.outputs = outputs
|
|
|
|
|
self.attrs = attrs
|
|
|
|
|
self.no_temp_outputs = no_temp_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_op_creation_method(op_proto):
|
|
|
|
|
"""
|
|
|
|
|
Generate op creation method for an OpProto
|
|
|
|
@ -155,15 +165,15 @@ def create_op_creation_method(op_proto):
|
|
|
|
|
opdesc = method(*args, **kwargs)
|
|
|
|
|
return core.Operator.create(opdesc.SerializeToString())
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
'method': __impl__,
|
|
|
|
|
'name': op_proto.type,
|
|
|
|
|
'all_inputs': [var.name for var in op_proto.inputs],
|
|
|
|
|
'all_outputs': [var.name for var in op_proto.outputs],
|
|
|
|
|
'all_attrs': [attr.name for attr in op_proto.attrs],
|
|
|
|
|
'all_no_temp_outputs':
|
|
|
|
|
[var.name for var in op_proto.outputs if not var.temporary]
|
|
|
|
|
}
|
|
|
|
|
return OpInfo(
|
|
|
|
|
method=__impl__,
|
|
|
|
|
name=op_proto.type,
|
|
|
|
|
inputs=[var.name for var in op_proto.inputs],
|
|
|
|
|
outputs=[var.name for var in op_proto.outputs],
|
|
|
|
|
attrs=[attr.name for attr in op_proto.attrs],
|
|
|
|
|
no_temp_outputs=[
|
|
|
|
|
var.name for var in op_proto.outputs if not var.temporary
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OperatorFactory(object):
|
|
|
|
@ -171,7 +181,7 @@ class OperatorFactory(object):
|
|
|
|
|
self.op_methods = dict()
|
|
|
|
|
for op_proto in get_all_op_protos():
|
|
|
|
|
method = create_op_creation_method(op_proto)
|
|
|
|
|
self.op_methods[method['name']] = method
|
|
|
|
|
self.op_methods[method.name] = method
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
|
|
if 'type' in kwargs:
|
|
|
|
@ -185,27 +195,27 @@ class OperatorFactory(object):
|
|
|
|
|
"argument except type")
|
|
|
|
|
t = args[0]
|
|
|
|
|
|
|
|
|
|
return self.get_op_creation_info(t)['method'](**kwargs)
|
|
|
|
|
return self.get_op_info(t).method(**kwargs)
|
|
|
|
|
|
|
|
|
|
def types(self):
|
|
|
|
|
return self.op_methods.keys()
|
|
|
|
|
|
|
|
|
|
def get_op_creation_info(self, t):
|
|
|
|
|
def get_op_info(self, t):
|
|
|
|
|
if t not in self.op_methods:
|
|
|
|
|
raise ValueError("operator %s is not registered", t)
|
|
|
|
|
return self.op_methods.get(t)
|
|
|
|
|
|
|
|
|
|
def get_op_input_names(self, type):
|
|
|
|
|
return self.get_op_creation_info(type)['all_inputs']
|
|
|
|
|
return self.get_op_info(type).inputs
|
|
|
|
|
|
|
|
|
|
def get_op_output_names(self, type):
|
|
|
|
|
return self.get_op_creation_info(type)['all_outputs']
|
|
|
|
|
return self.get_op_info(type).outputs
|
|
|
|
|
|
|
|
|
|
def get_op_attr_names(self, type):
|
|
|
|
|
return self.get_op_creation_info(type)['all_attrs']
|
|
|
|
|
return self.get_op_info(type).attrs
|
|
|
|
|
|
|
|
|
|
def get_op_no_temp_output_names(self, type):
|
|
|
|
|
return self.get_op_creation_info(type)['all_no_temp_outputs']
|
|
|
|
|
return self.get_op_info(type).no_temp_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Operator = OperatorFactory() # Default global factory
|
|
|
|
|