|
|
@ -7,40 +7,6 @@ import copy
|
|
|
|
__all__ = ['Block', 'Variable', 'Program', 'Operator']
|
|
|
|
__all__ = ['Block', 'Variable', 'Program', 'Operator']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_all_op_protos():
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Get all registered op proto from PaddlePaddle C++ end.
|
|
|
|
|
|
|
|
:return: A list of registered OpProto.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
protostrs = core.get_all_op_protos()
|
|
|
|
|
|
|
|
ret_values = []
|
|
|
|
|
|
|
|
for pbstr in protostrs:
|
|
|
|
|
|
|
|
op_proto = framework_pb2.OpProto.FromString(str(pbstr))
|
|
|
|
|
|
|
|
ret_values.append(op_proto)
|
|
|
|
|
|
|
|
return ret_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpProtoHolder(object):
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def instance(cls):
|
|
|
|
|
|
|
|
if not hasattr(cls, '_instance'):
|
|
|
|
|
|
|
|
cls._instance = cls()
|
|
|
|
|
|
|
|
return cls._instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
|
|
assert not hasattr(
|
|
|
|
|
|
|
|
self.__class__,
|
|
|
|
|
|
|
|
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
|
|
|
|
|
|
|
|
op_protos = get_all_op_protos()
|
|
|
|
|
|
|
|
self.op_proto_map = {}
|
|
|
|
|
|
|
|
for proto in op_protos:
|
|
|
|
|
|
|
|
self.op_proto_map[proto.type] = proto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_op_proto(self, type):
|
|
|
|
|
|
|
|
assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type
|
|
|
|
|
|
|
|
return self.op_proto_map[type]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Variable(object):
|
|
|
|
class Variable(object):
|
|
|
|
def __init__(self,
|
|
|
|
def __init__(self,
|
|
|
|
block,
|
|
|
|
block,
|
|
|
@ -141,6 +107,40 @@ class Variable(object):
|
|
|
|
raise ValueError("Not supported numpy dtype " + str(dtype))
|
|
|
|
raise ValueError("Not supported numpy dtype " + str(dtype))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_all_op_protos():
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Get all registered op proto from PaddlePaddle C++ end.
|
|
|
|
|
|
|
|
:return: A list of registered OpProto.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
protostrs = core.get_all_op_protos()
|
|
|
|
|
|
|
|
ret_values = []
|
|
|
|
|
|
|
|
for pbstr in protostrs:
|
|
|
|
|
|
|
|
op_proto = framework_pb2.OpProto.FromString(str(pbstr))
|
|
|
|
|
|
|
|
ret_values.append(op_proto)
|
|
|
|
|
|
|
|
return ret_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpProtoHolder(object):
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def instance(cls):
|
|
|
|
|
|
|
|
if not hasattr(cls, '_instance'):
|
|
|
|
|
|
|
|
cls._instance = cls()
|
|
|
|
|
|
|
|
return cls._instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
|
|
assert not hasattr(
|
|
|
|
|
|
|
|
self.__class__,
|
|
|
|
|
|
|
|
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
|
|
|
|
|
|
|
|
op_protos = get_all_op_protos()
|
|
|
|
|
|
|
|
self.op_proto_map = {}
|
|
|
|
|
|
|
|
for proto in op_protos:
|
|
|
|
|
|
|
|
self.op_proto_map[proto.type] = proto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_op_proto(self, type):
|
|
|
|
|
|
|
|
assert type in self.op_proto_map, "Operator \"%s\" has not been registered." % type
|
|
|
|
|
|
|
|
return self.op_proto_map[type]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Operator(object):
|
|
|
|
class Operator(object):
|
|
|
|
def __init__(self, block, desc, type, inputs=None, outputs=None,
|
|
|
|
def __init__(self, block, desc, type, inputs=None, outputs=None,
|
|
|
|
attrs=None):
|
|
|
|
attrs=None):
|
|
|
|