From 80d2214361a2040e340a396b2353615c486eb480 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Wed, 30 Sep 2020 11:07:52 +0800 Subject: [PATCH] fix mssim precision when dtype is uint32. --- mindspore/nn/layer/image.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 159749c960..d8252910ee 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """image""" +import numbers import numpy as np import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor @@ -93,6 +94,16 @@ def _convert_img_dtype_to_float32(img, max_val): ret = ret * scale return ret +@constexpr +def _get_dtype_max(dtype): + """get max of the dtype""" + np_type = mstype.dtype_to_nptype(dtype) + if issubclass(np_type, numbers.Integral): + dtype_max = np.float64(np.iinfo(np_type).max) + else: + dtype_max = 1.0 + return dtype_max + @constexpr def _check_input_4d(input_shape, param_name, func_name): if len(input_shape) != 4: @@ -224,9 +235,11 @@ class SSIM(Cell): _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name) _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name) P.SameTypeShape()(img1, img2) - max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) - img1 = _convert_img_dtype_to_float32(img1, self.max_val) - img2 = _convert_img_dtype_to_float32(img2, self.max_val) + dtype_max_val = _get_dtype_max(F.dtype(img1)) + max_val = F.scalar_cast(self.max_val, F.dtype(img1)) + max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val) + img1 = _convert_img_dtype_to_float32(img1, dtype_max_val) + img2 = _convert_img_dtype_to_float32(img2, dtype_max_val) c1 = (self.k1 * max_val) ** 2 c2 = (self.k2 * max_val) ** 2 @@ -309,10 +322,13 @@ class MSSSIM(Cell): def construct(self, img1, img2): _check_input_4d(F.shape(img1), "img1", self.cls_name) _check_input_4d(F.shape(img2), "img2", self.cls_name) + _check_input_dtype(F.dtype(img1), 'img1', mstype.number_type, self.cls_name) P.SameTypeShape()(img1, img2) - max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) - img1 = _convert_img_dtype_to_float32(img1, self.max_val) - img2 = _convert_img_dtype_to_float32(img2, self.max_val) + dtype_max_val = _get_dtype_max(F.dtype(img1)) + max_val = F.scalar_cast(self.max_val, F.dtype(img1)) + max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val) + img1 = _convert_img_dtype_to_float32(img1, dtype_max_val) + img2 = _convert_img_dtype_to_float32(img2, dtype_max_val) c1 = (self.k1 * max_val) ** 2 c2 = (self.k2 * max_val) ** 2 @@ -375,9 +391,11 @@ class PSNR(Cell): _check_input_4d(F.shape(img1), "img1", self.cls_name) _check_input_4d(F.shape(img2), "img2", self.cls_name) P.SameTypeShape()(img1, img2) - max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) - img1 = _convert_img_dtype_to_float32(img1, self.max_val) - img2 = _convert_img_dtype_to_float32(img2, self.max_val) + dtype_max_val = _get_dtype_max(F.dtype(img1)) + max_val = F.scalar_cast(self.max_val, F.dtype(img1)) + max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val) + img1 = _convert_img_dtype_to_float32(img1, dtype_max_val) + img2 = _convert_img_dtype_to_float32(img2, dtype_max_val) mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1)) psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)