polish printing dtype (#30682)

* polish printing dtype

* fix special case
revert-31068-fix_conv3d_windows
Leo Chen 5 years ago committed by GitHub
parent 5bf25d1e8b
commit 1a13626f5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,31 +26,25 @@ from .framework import Variable, default_main_program, _current_expected_place,
from .framework import _cpu_num, _cuda_ids from .framework import _cpu_num, _cuda_ids
__all__ = ['DataFeeder'] __all__ = ['DataFeeder']
_PADDLE_DTYPE_2_NUMPY_DTYPE = {
core.VarDesc.VarType.BOOL: 'bool',
core.VarDesc.VarType.FP16: 'float16',
core.VarDesc.VarType.FP32: 'float32',
core.VarDesc.VarType.FP64: 'float64',
core.VarDesc.VarType.INT8: 'int8',
core.VarDesc.VarType.INT16: 'int16',
core.VarDesc.VarType.INT32: 'int32',
core.VarDesc.VarType.INT64: 'int64',
core.VarDesc.VarType.UINT8: 'uint8',
core.VarDesc.VarType.COMPLEX64: 'complex64',
core.VarDesc.VarType.COMPLEX128: 'complex128',
}
def convert_dtype(dtype): def convert_dtype(dtype):
if isinstance(dtype, core.VarDesc.VarType): if isinstance(dtype, core.VarDesc.VarType):
if dtype == core.VarDesc.VarType.BOOL: if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
return 'bool' return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
elif dtype == core.VarDesc.VarType.FP16:
return 'float16'
elif dtype == core.VarDesc.VarType.FP32:
return 'float32'
elif dtype == core.VarDesc.VarType.FP64:
return 'float64'
elif dtype == core.VarDesc.VarType.INT8:
return 'int8'
elif dtype == core.VarDesc.VarType.INT16:
return 'int16'
elif dtype == core.VarDesc.VarType.INT32:
return 'int32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
elif dtype == core.VarDesc.VarType.COMPLEX64:
return 'complex64'
elif dtype == core.VarDesc.VarType.COMPLEX128:
return 'complex128'
elif isinstance(dtype, type): elif isinstance(dtype, type):
if dtype in [ if dtype in [
np.bool, np.float16, np.float32, np.float64, np.int8, np.int16, np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,

@ -23,6 +23,7 @@ from ..framework import Variable, Parameter, ParamBase
from .base import switch_to_static_graph from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss from .parallel import scale_loss
from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
def monkey_patch_varbase(): def monkey_patch_varbase():
@ -319,5 +320,20 @@ def monkey_patch_varbase():
("__name__", "Tensor")): ("__name__", "Tensor")):
setattr(core.VarBase, method_name, method) setattr(core.VarBase, method_name, method)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
# So, we need to overwrite it to a more readable one.
# See details in https://github.com/pybind/pybind11/issues/2537.
origin = getattr(core.VarDesc.VarType, "__repr__")
def dtype_str(dtype):
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
prefix = 'paddle.'
return prefix + _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
else:
# for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR
return origin(dtype)
setattr(core.VarDesc.VarType, "__repr__", dtype_str)
# patch math methods for varbase # patch math methods for varbase
monkey_patch_math_varbase() monkey_patch_math_varbase()

@ -617,6 +617,16 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static() paddle.enable_static()
def test_print_tensor_dtype(self):
paddle.disable_static(paddle.CPUPlace())
a = paddle.rand([1])
a_str = str(a.dtype)
expected = 'paddle.float32'
self.assertEqual(a_str, expected)
paddle.enable_static()
class TestVarBaseSetitem(unittest.TestCase): class TestVarBaseSetitem(unittest.TestCase):
def setUp(self): def setUp(self):

Loading…
Cancel
Save