|
|
|
@ -29,35 +29,31 @@ class NetworkFunctor(object):
|
|
|
|
|
if ipt in kwargs:
|
|
|
|
|
var = kwargs[ipt]
|
|
|
|
|
if isinstance(var, basestring):
|
|
|
|
|
var_name = var
|
|
|
|
|
var = create_var(var)
|
|
|
|
|
self.net.var_name_map[var] = var_name
|
|
|
|
|
if not isinstance(var, core.Variable):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"Input of op creation must be string or variable")
|
|
|
|
|
|
|
|
|
|
kwargs[ipt] = self.net.var_name_map[var]
|
|
|
|
|
kwargs[ipt] = get_cur_scope().get_var_name(var)
|
|
|
|
|
|
|
|
|
|
notemp_outputs = self.func.all_not_temp_output_args
|
|
|
|
|
|
|
|
|
|
for name in notemp_outputs:
|
|
|
|
|
if name not in kwargs:
|
|
|
|
|
kwargs[
|
|
|
|
|
name] = self.func.__name__ + "@OUT@%d" % self.net.generate_idx
|
|
|
|
|
self.net.generate_idx += 1
|
|
|
|
|
name] = self.func.__name__ + "@OUT@%d" % core.unique_integer(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
outputs = self.func.all_output_args
|
|
|
|
|
for opt in outputs:
|
|
|
|
|
if opt in kwargs:
|
|
|
|
|
var = kwargs[opt]
|
|
|
|
|
if isinstance(var, basestring):
|
|
|
|
|
var_name = var
|
|
|
|
|
var = create_var(var)
|
|
|
|
|
self.net.var_name_map[var] = var_name
|
|
|
|
|
if not isinstance(var, core.Variable):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"Output of op creation must be string or variable")
|
|
|
|
|
kwargs[opt] = self.net.var_name_map[var]
|
|
|
|
|
kwargs[opt] = get_cur_scope().get_var_name(var)
|
|
|
|
|
|
|
|
|
|
op = self.func(**kwargs)
|
|
|
|
|
|
|
|
|
@ -93,8 +89,6 @@ class Network(object):
|
|
|
|
|
self.net = core.Net.create()
|
|
|
|
|
funcs = (func_name for func_name in dir(op_creations)
|
|
|
|
|
if not func_name.startswith("__"))
|
|
|
|
|
self.generate_idx = 0
|
|
|
|
|
self.var_name_map = dict()
|
|
|
|
|
|
|
|
|
|
# TODO(yuyang18): This code can work, but do not generate a good
|
|
|
|
|
# docstring, try to give a better way generate function in runtime
|
|
|
|
|