|
|
|
@ -212,7 +212,7 @@ def _debug_string_(proto, throw_on_error=True):
|
|
|
|
|
return proto.__str__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Variable(core.VariableBase):
|
|
|
|
|
class Variable(core.VarBase):
|
|
|
|
|
"""
|
|
|
|
|
In Fluid, every input and output of an operator is a variable. In most
|
|
|
|
|
cases, variables are used for holding different kinds of data or training
|
|
|
|
@ -507,7 +507,7 @@ class OpProtoHolder(object):
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Operator(object):
|
|
|
|
|
class Operator(core.OpBase):
|
|
|
|
|
"""
|
|
|
|
|
In Fluid, all the operation are represented by Operator, and Operator
|
|
|
|
|
is regarded as a build in an instruction of a Block. Users can use the
|
|
|
|
@ -602,20 +602,20 @@ class Operator(object):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
if inputs is not None:
|
|
|
|
|
self.inputs = [] if not inputs else inputs
|
|
|
|
|
for in_proto in proto.inputs:
|
|
|
|
|
found = find_name(inputs, in_proto.name)
|
|
|
|
|
found = find_name(self.inputs, in_proto.name)
|
|
|
|
|
assert found or in_proto.dispensable, "Input {} not found".format(
|
|
|
|
|
in_proto.name)
|
|
|
|
|
|
|
|
|
|
if found:
|
|
|
|
|
in_args = inputs[in_proto.name]
|
|
|
|
|
in_args = self.inputs[in_proto.name]
|
|
|
|
|
if not isinstance(in_args, list):
|
|
|
|
|
in_args = [in_args]
|
|
|
|
|
if not in_proto.duplicable and len(in_args) > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Input %s expects only one input, but %d are given."
|
|
|
|
|
% (in_proto.name, len(in_args)))
|
|
|
|
|
"Input %s expects only one input, but %d are given." %
|
|
|
|
|
(in_proto.name, len(in_args)))
|
|
|
|
|
in_arg_names = []
|
|
|
|
|
for arg in in_args:
|
|
|
|
|
if isinstance(arg, six.string_types):
|
|
|
|
@ -628,6 +628,7 @@ class Operator(object):
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_input(in_proto.name, [])
|
|
|
|
|
|
|
|
|
|
self.outputs = [] if not outputs else outputs
|
|
|
|
|
if outputs is not None:
|
|
|
|
|
given = set()
|
|
|
|
|
need = set()
|
|
|
|
@ -1222,7 +1223,8 @@ class Block(object):
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
op_desc = core.OpDesc()
|
|
|
|
|
op = Operator(block=self, desc=op_desc, *args, **kwargs)
|
|
|
|
|
_imperative_tracer().trace(op.desc)
|
|
|
|
|
sys.stderr.write('%s %s!!!\n' % (type(op.inputs), type(op.outputs)))
|
|
|
|
|
_imperative_tracer().trace(op, op.inputs, op.outputs)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
op_desc = self.desc.append_op()
|
|
|
|
|