|
|
|
@ -71,20 +71,18 @@ class Variable(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Operator(object):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
block,
|
|
|
|
|
proto,
|
|
|
|
|
type=None,
|
|
|
|
|
inputs=None,
|
|
|
|
|
outputs=None,
|
|
|
|
|
def __init__(self, block, desc, type, inputs=None, outputs=None,
|
|
|
|
|
attrs=None):
|
|
|
|
|
self.block = block
|
|
|
|
|
self.proto = proto
|
|
|
|
|
if type is not None:
|
|
|
|
|
self.proto.set_type(type)
|
|
|
|
|
self.desc = desc
|
|
|
|
|
self.proto = OpProtoHolder.instance().get_op_proto(type)
|
|
|
|
|
self.desc.set_type(type)
|
|
|
|
|
if inputs is not None:
|
|
|
|
|
for k, v in inputs.iteritems():
|
|
|
|
|
self.proto.set_input(k, v)
|
|
|
|
|
for in_proto in self.proto.inputs:
|
|
|
|
|
in_argu = inputs[in_proto.name]
|
|
|
|
|
if is_str(in_argu):
|
|
|
|
|
in_argu = [in_argu]
|
|
|
|
|
|
|
|
|
|
if outputs is not None:
|
|
|
|
|
for k, v in outputs.iteritems():
|
|
|
|
|
self.proto.set_output(k, v)
|
|
|
|
@ -114,8 +112,8 @@ class Block(object):
|
|
|
|
|
return Variable(self, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def append_op(self, *args, **kwargs):
|
|
|
|
|
op_proto = self.proto.append_op()
|
|
|
|
|
op = Operator(self, op_proto, *args, **kwargs)
|
|
|
|
|
op_desc = self.proto.append_op()
|
|
|
|
|
op = Operator(self, op_desc, *args, **kwargs)
|
|
|
|
|
self.ops.append(op)
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|