|
|
|
@ -1,9 +1,44 @@
|
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
|
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
|
|
|
|
|
import collections
|
|
|
|
|
|
|
|
|
|
__all__ = ['Block', 'Variable', 'Program', 'Operator']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_all_op_protos():
|
|
|
|
|
"""
|
|
|
|
|
Get all registered op proto from PaddlePaddle C++ end.
|
|
|
|
|
:return: A list of registered OpProto.
|
|
|
|
|
"""
|
|
|
|
|
protostrs = core.get_all_op_protos()
|
|
|
|
|
ret_values = []
|
|
|
|
|
for pbstr in protostrs:
|
|
|
|
|
op_proto = framework_pb2.OpProto.FromString(str(pbstr))
|
|
|
|
|
ret_values.append(op_proto)
|
|
|
|
|
return ret_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpProtoHolder(object):
|
|
|
|
|
@classmethod
|
|
|
|
|
def instance(cls):
|
|
|
|
|
if not hasattr(cls, '_instance'):
|
|
|
|
|
cls._instance = cls()
|
|
|
|
|
return cls._instance
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
assert not hasattr(
|
|
|
|
|
self.__class__,
|
|
|
|
|
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
|
|
|
|
|
op_protos = get_all_op_protos()
|
|
|
|
|
self.op_proto_map = {}
|
|
|
|
|
for proto in op_protos:
|
|
|
|
|
sefl.op_proto_map[proto.type] = proto
|
|
|
|
|
|
|
|
|
|
def get_op_proto(self, type):
|
|
|
|
|
assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type
|
|
|
|
|
return self.op_proto_map[type]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Variable(object):
|
|
|
|
|
def __init__(self, block, name=None, shape=None, dtype=None,
|
|
|
|
|
lod_level=None):
|
|
|
|
@ -36,27 +71,50 @@ class Variable(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Operator(object):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
block,
|
|
|
|
|
proto,
|
|
|
|
|
type=None,
|
|
|
|
|
inputs=None,
|
|
|
|
|
outputs=None,
|
|
|
|
|
def __init__(self, block, desc, type, inputs=None, outputs=None,
|
|
|
|
|
attrs=None):
|
|
|
|
|
self.block = block
|
|
|
|
|
self.proto = proto
|
|
|
|
|
if type is not None:
|
|
|
|
|
# TODO.
|
|
|
|
|
pass
|
|
|
|
|
self.desc = desc
|
|
|
|
|
self.proto = OpProtoHolder.instance().get_op_proto(type)
|
|
|
|
|
self.desc.set_type(type)
|
|
|
|
|
|
|
|
|
|
if inputs is not None:
|
|
|
|
|
# TODO
|
|
|
|
|
pass
|
|
|
|
|
for in_proto in self.proto.inputs:
|
|
|
|
|
in_argus = inputs[in_proto.name]
|
|
|
|
|
if not isinstance(in_argus, list):
|
|
|
|
|
in_argus = [in_argus]
|
|
|
|
|
if not in_proto.duplicable and len(in_argus) > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Input %s expects only one input, but %d are given." %
|
|
|
|
|
(in_proto.name, len(in_argus)))
|
|
|
|
|
in_argu_names = []
|
|
|
|
|
for argu in in_argus:
|
|
|
|
|
in_argu_names.append(argu.name())
|
|
|
|
|
self.desc.set_input(in_proto.name, in_argu_names)
|
|
|
|
|
|
|
|
|
|
if outputs is not None:
|
|
|
|
|
# TODO
|
|
|
|
|
pass
|
|
|
|
|
for out_proto in self.proto.outputs:
|
|
|
|
|
out_argus = outputs[out_proto.name]
|
|
|
|
|
if not isinstance(out_argus, list):
|
|
|
|
|
out_argus = [out_argus]
|
|
|
|
|
if not out_proto.duplicable and len(out_argus) > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Output %s expects only one output, but %d are given." %
|
|
|
|
|
(out_proto.name, len(out_argus)))
|
|
|
|
|
out_argu_names = []
|
|
|
|
|
for argu in out_argus:
|
|
|
|
|
out_argu_names.append(argu.name())
|
|
|
|
|
self.desc.set_output(out_proto.name, out_argu_names)
|
|
|
|
|
|
|
|
|
|
if attrs is not None:
|
|
|
|
|
# TODO
|
|
|
|
|
pass
|
|
|
|
|
for attr in self.proto.attrs:
|
|
|
|
|
attr_name = attr.name
|
|
|
|
|
if not attr_name in attrs:
|
|
|
|
|
continue
|
|
|
|
|
if not isinstance(attrs[attr_name], Block):
|
|
|
|
|
self.desc.set_attr(attr_name, attrs[attr_name])
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
|
|
|
|
|
|
|
|
|
|
# TODO: Getters
|
|
|
|
|
|
|
|
|
@ -80,8 +138,8 @@ class Block(object):
|
|
|
|
|
return Variable(self, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def append_op(self, *args, **kwargs):
|
|
|
|
|
op_proto = self.proto.append_op()
|
|
|
|
|
op = Operator(self, op_proto, *args, **kwargs)
|
|
|
|
|
op_desc = self.proto.append_op()
|
|
|
|
|
op = Operator(self, op_desc, *args, **kwargs)
|
|
|
|
|
self.ops.append(op)
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|