|
|
|
@ -563,6 +563,7 @@ class Operator(core.OpBase):
|
|
|
|
|
inputs=None,
|
|
|
|
|
outputs=None,
|
|
|
|
|
attrs=None):
|
|
|
|
|
core.OpBase.__init__(self)
|
|
|
|
|
self.block = block
|
|
|
|
|
self.desc = desc
|
|
|
|
|
# note: not add self.attrs here:
|
|
|
|
@ -602,33 +603,32 @@ class Operator(core.OpBase):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
self.inputs = [] if not inputs else inputs
|
|
|
|
|
for in_proto in proto.inputs:
|
|
|
|
|
found = find_name(self.inputs, in_proto.name)
|
|
|
|
|
assert found or in_proto.dispensable, "Input {} not found".format(
|
|
|
|
|
in_proto.name)
|
|
|
|
|
|
|
|
|
|
if found:
|
|
|
|
|
in_args = self.inputs[in_proto.name]
|
|
|
|
|
if not isinstance(in_args, list):
|
|
|
|
|
in_args = [in_args]
|
|
|
|
|
if not in_proto.duplicable and len(in_args) > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Input %s expects only one input, but %d are given." %
|
|
|
|
|
(in_proto.name, len(in_args)))
|
|
|
|
|
in_arg_names = []
|
|
|
|
|
for arg in in_args:
|
|
|
|
|
if isinstance(arg, six.string_types):
|
|
|
|
|
in_arg_names.append(arg)
|
|
|
|
|
elif isinstance(arg, six.binary_type):
|
|
|
|
|
in_arg_names.append(arg.decode())
|
|
|
|
|
else:
|
|
|
|
|
in_arg_names.append(cpt.to_text(arg.name))
|
|
|
|
|
self.desc.set_input(in_proto.name, in_arg_names)
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_input(in_proto.name, [])
|
|
|
|
|
if inputs is not None:
|
|
|
|
|
for in_proto in proto.inputs:
|
|
|
|
|
found = find_name(inputs, in_proto.name)
|
|
|
|
|
assert found or in_proto.dispensable, "Input {} not found".format(
|
|
|
|
|
in_proto.name)
|
|
|
|
|
|
|
|
|
|
if found:
|
|
|
|
|
in_args = inputs[in_proto.name]
|
|
|
|
|
if not isinstance(in_args, list):
|
|
|
|
|
in_args = [in_args]
|
|
|
|
|
if not in_proto.duplicable and len(in_args) > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Input %s expects only one input, but %d are given."
|
|
|
|
|
% (in_proto.name, len(in_args)))
|
|
|
|
|
in_arg_names = []
|
|
|
|
|
for arg in in_args:
|
|
|
|
|
if isinstance(arg, six.string_types):
|
|
|
|
|
in_arg_names.append(arg)
|
|
|
|
|
elif isinstance(arg, six.binary_type):
|
|
|
|
|
in_arg_names.append(arg.decode())
|
|
|
|
|
else:
|
|
|
|
|
in_arg_names.append(cpt.to_text(arg.name))
|
|
|
|
|
self.desc.set_input(in_proto.name, in_arg_names)
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_input(in_proto.name, [])
|
|
|
|
|
|
|
|
|
|
self.outputs = [] if not outputs else outputs
|
|
|
|
|
if outputs is not None:
|
|
|
|
|
given = set()
|
|
|
|
|
need = set()
|
|
|
|
@ -657,6 +657,21 @@ class Operator(core.OpBase):
|
|
|
|
|
arg.op = self
|
|
|
|
|
self.desc.set_output(out_proto.name, out_arg_names)
|
|
|
|
|
|
|
|
|
|
input_vars = []
|
|
|
|
|
for inp in inputs.values():
|
|
|
|
|
if isinstance(inp, Variable):
|
|
|
|
|
input_vars.append(inp)
|
|
|
|
|
elif isinstance(inp, list):
|
|
|
|
|
input_vars.extend(inp[:])
|
|
|
|
|
self.inputs = input_vars
|
|
|
|
|
output_vars = []
|
|
|
|
|
for out in outputs.values():
|
|
|
|
|
if isinstance(out, Variable):
|
|
|
|
|
output_vars.append(out)
|
|
|
|
|
elif isinstance(inp, list):
|
|
|
|
|
output_vars.extend(out[:])
|
|
|
|
|
self.outputs = output_vars
|
|
|
|
|
|
|
|
|
|
if op_attrs is not None:
|
|
|
|
|
if not isinstance(op_attrs, dict):
|
|
|
|
|
raise TypeError("'attrs' should be a dict.")
|
|
|
|
|