fix bug that tuple(Variable) is converted to list(Variable) uncorrectly (#21687)

paddle_tiny_install
zhouwei25 5 years ago committed by Tao Luo
parent a5159d8480
commit e92d113590

@ -1814,8 +1814,8 @@ class Operator(object):
"The type of '%s' in operator %s should be "
"one of [basestring(), str, Varibale] in python2, "
"or one of [str, bytes, Variable] in python3."
"but received : " % (in_proto.name, type),
arg)
"but received : %s" %
(in_proto.name, type, arg))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])

@ -12280,16 +12280,18 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
x = []
elif isinstance(x, Variable):
x = [x]
elif not isinstance(x, (list, tuple)):
elif isinstance(x, tuple):
x = list(x)
elif not isinstance(x, (list, tuple, Variable)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
if out is None:
out_list = []
elif isinstance(out, Variable):
out_list = [out]
elif isinstance(out, (list, tuple)):
out_list = out
else:
elif isinstance(out, tuple):
out_list = list(out)
elif not isinstance(x, (list, tuple, Variable)):
raise TypeError(
'Output must be Variable/list(Variable)/tuple(Variable)')

Loading…
Cancel
Save