skip type/dtype check in dygraph mode, test=develop (#22915)

revert-22710-feature/integrated_ps_api
Chen Weihang 5 years ago committed by GitHub
parent c5cbe7f07b
commit e081c7a05d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -22,7 +22,7 @@ from six.moves import zip, range, xrange
import multiprocessing
import warnings
from .framework import Variable, default_main_program, _current_expected_place
from .framework import Variable, default_main_program, _current_expected_place, in_dygraph_mode
from .framework import _cpu_num, _cuda_ids
__all__ = ['DataFeeder']
@ -81,6 +81,15 @@ def check_variable_and_dtype(input,
def check_type(input, input_name, expected_type, op_name, extra_message=''):
# NOTE [ Why skip dynamic graph check ]:
# 1. If the input type / dtype of a layer is wrong, it will be reported
# directly on that line. User can easily print the relevant information
# on which line. It is easier to debug, so there is no need to check
# in dynamic graph mode.
# 2. Performance considerations. Because these checks are executed at
# each step in dynamic graph mode, it will bring a heavy performance burden.
if in_dygraph_mode():
return
if not isinstance(input, expected_type):
raise TypeError(
"The type of '%s' in %s must be %s, but received %s. %s" %
@ -92,6 +101,9 @@ def check_dtype(input_dtype,
expected_dtype,
op_name,
extra_message=''):
# See NOTE [ Why skip dynamic graph check ]
if in_dygraph_mode():
return
if convert_dtype(input_dtype) in ['float16']:
warnings.warn(
"The data type of '%s' in %s only support float16 in GPU now. %s" %

Loading…
Cancel
Save