|
|
|
@ -4,7 +4,10 @@ import collections
|
|
|
|
|
import numpy as np
|
|
|
|
|
import copy
|
|
|
|
|
|
|
|
|
|
__all__ = ['Block', 'Variable', 'Program', 'Operator', 'default_startup_program', 'default_main_program']
|
|
|
|
|
__all__ = [
|
|
|
|
|
'Block', 'Variable', 'Program', 'Operator', 'default_startup_program',
|
|
|
|
|
'default_main_program'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unique_name(prefix):
|
|
|
|
@ -232,17 +235,17 @@ class Operator(object):
|
|
|
|
|
in_proto.name)
|
|
|
|
|
|
|
|
|
|
if found:
|
|
|
|
|
in_argus = inputs[in_proto.name]
|
|
|
|
|
if not isinstance(in_argus, list):
|
|
|
|
|
in_argus = [in_argus]
|
|
|
|
|
if not in_proto.duplicable and len(in_argus) > 1:
|
|
|
|
|
in_args = inputs[in_proto.name]
|
|
|
|
|
if not isinstance(in_args, list):
|
|
|
|
|
in_args = [in_args]
|
|
|
|
|
if not in_proto.duplicable and len(in_args) > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Input %s expects only one input, but %d are given."
|
|
|
|
|
% (in_proto.name, len(in_argus)))
|
|
|
|
|
in_argu_names = []
|
|
|
|
|
for argu in in_argus:
|
|
|
|
|
in_argu_names.append(argu.name)
|
|
|
|
|
self.desc.set_input(in_proto.name, in_argu_names)
|
|
|
|
|
% (in_proto.name, len(in_args)))
|
|
|
|
|
in_arg_names = []
|
|
|
|
|
for arg in in_args:
|
|
|
|
|
in_arg_names.append(arg.name)
|
|
|
|
|
self.desc.set_input(in_proto.name, in_arg_names)
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_input(in_proto.name, [])
|
|
|
|
|
|
|
|
|
@ -260,18 +263,18 @@ class Operator(object):
|
|
|
|
|
str(e) for e in given)))
|
|
|
|
|
|
|
|
|
|
for out_proto in proto.outputs:
|
|
|
|
|
out_argus = outputs[out_proto.name]
|
|
|
|
|
if not isinstance(out_argus, list):
|
|
|
|
|
out_argus = [out_argus]
|
|
|
|
|
if not out_proto.duplicable and len(out_argus) > 1:
|
|
|
|
|
out_args = outputs[out_proto.name]
|
|
|
|
|
if not isinstance(out_args, list):
|
|
|
|
|
out_args = [out_args]
|
|
|
|
|
if not out_proto.duplicable and len(out_args) > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Output %s expects only one output, but %d are given." %
|
|
|
|
|
(out_proto.name, len(out_argus)))
|
|
|
|
|
out_argu_names = []
|
|
|
|
|
for argu in out_argus:
|
|
|
|
|
out_argu_names.append(argu.name)
|
|
|
|
|
argu.op = self
|
|
|
|
|
self.desc.set_output(out_proto.name, out_argu_names)
|
|
|
|
|
(out_proto.name, len(out_args)))
|
|
|
|
|
out_arg_names = []
|
|
|
|
|
for arg in out_args:
|
|
|
|
|
out_arg_names.append(arg.name)
|
|
|
|
|
arg.op = self
|
|
|
|
|
self.desc.set_output(out_proto.name, out_arg_names)
|
|
|
|
|
|
|
|
|
|
if attrs is not None:
|
|
|
|
|
if not isinstance(attrs, dict):
|
|
|
|
@ -582,8 +585,10 @@ class Parameter(Variable):
|
|
|
|
|
g_main_program = Program()
|
|
|
|
|
g_startup_program = Program()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_startup_program():
|
|
|
|
|
return g_startup_program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_main_program():
|
|
|
|
|
return g_main_program
|
|
|
|
|