|
|
@ -60,7 +60,7 @@ def convert_dtype(dtype):
|
|
|
|
u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8'
|
|
|
|
u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8'
|
|
|
|
]:
|
|
|
|
]:
|
|
|
|
# this code is a little bit dangerous, since error could happen
|
|
|
|
# this code is a little bit dangerous, since error could happen
|
|
|
|
# when casting no-asci code to str in python2.
|
|
|
|
# when casting no-ascii code to str in python2.
|
|
|
|
# but since the set itself is limited, so currently, it is good.
|
|
|
|
# but since the set itself is limited, so currently, it is good.
|
|
|
|
# however, jointly supporting python2 and python3, (as well as python4 maybe)
|
|
|
|
# however, jointly supporting python2 and python3, (as well as python4 maybe)
|
|
|
|
# may still be a long-lasting problem.
|
|
|
|
# may still be a long-lasting problem.
|
|
|
@ -76,8 +76,7 @@ def check_variable_and_dtype(input,
|
|
|
|
expected_dtype,
|
|
|
|
expected_dtype,
|
|
|
|
op_name,
|
|
|
|
op_name,
|
|
|
|
extra_message=''):
|
|
|
|
extra_message=''):
|
|
|
|
check_type(input, input_name, (Variable, core.VarBase), op_name,
|
|
|
|
check_type(input, input_name, Variable, op_name, extra_message)
|
|
|
|
extra_message)
|
|
|
|
|
|
|
|
check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
|
|
|
|
check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -91,6 +90,22 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''):
|
|
|
|
# each step in dynamic graph mode, it will bring a heavy performance burden.
|
|
|
|
# each step in dynamic graph mode, it will bring a heavy performance burden.
|
|
|
|
if in_dygraph_mode():
|
|
|
|
if in_dygraph_mode():
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .dygraph.dygraph_to_static.program_translator import in_declarative_mode
|
|
|
|
|
|
|
|
# NOTE: `in_declarative_mode` is used to determined whether this op is called under
|
|
|
|
|
|
|
|
# @declarative in transformation from dygrah to static layer. We add VarBase in
|
|
|
|
|
|
|
|
# expected_type to skip checking because varBase may be created and used in unusual way.
|
|
|
|
|
|
|
|
# Need a better design to be fix this.
|
|
|
|
|
|
|
|
if in_declarative_mode():
|
|
|
|
|
|
|
|
if not isinstance(expected_type, tuple):
|
|
|
|
|
|
|
|
expected_type = (expected_type, )
|
|
|
|
|
|
|
|
expected_type += (core.VarBase, )
|
|
|
|
|
|
|
|
elif isinstance(input, core.VarBase):
|
|
|
|
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
"Please use `with fluid.dygraph.guard()` as context or `fluid.enable_dygraph()` to switch to imperative mode firstly. "
|
|
|
|
|
|
|
|
"Because received '{}' in {} is a imperative Variable.".format(
|
|
|
|
|
|
|
|
input_name, op_name))
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(input, expected_type):
|
|
|
|
if not isinstance(input, expected_type):
|
|
|
|
raise TypeError(
|
|
|
|
raise TypeError(
|
|
|
|
"The type of '%s' in %s must be %s, but received %s. %s" %
|
|
|
|
"The type of '%s' in %s must be %s, but received %s. %s" %
|
|
|
|