From ea681bfb0f90b05220eb522f05b00118227726ce Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Tue, 16 Jun 2020 12:16:24 +0800 Subject: [PATCH] fix ssim filter size check --- mindspore/nn/layer/image.py | 9 +++++++-- tests/ut/python/ops/test_ops.py | 4 ++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 7d8eef4d6f..39cc7895f3 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -104,6 +104,12 @@ def _check_input_4d(input_shape, param_name, func_name): raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}") return True +@constexpr +def _check_input_filter_size(input_shape, param_name, filter_size, func_name): + _check_input_4d(input_shape, param_name, func_name) + validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name) + validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name) + class SSIM(Cell): r""" Returns SSIM index between img1 and img2. @@ -154,8 +160,7 @@ class SSIM(Cell): self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) def construct(self, img1, img2): - _check_input_4d(F.shape(img1), "img1", self.cls_name) - _check_input_4d(F.shape(img2), "img2", self.cls_name) + _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name) P.SameTypeShape()(img1, img2) max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) img1 = _convert_img_dtype_to_float32(img1, self.max_val) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 79d7de5d7d..95e5114fa3 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1754,6 +1754,10 @@ raise_set = [ 'block': (P.PReLU(), {'exception': ValueError}), 'desc_inputs': [[2], [1]], 'desc_bprop': [[1]]}), + ('SSIM', { + 'block': (nn.SSIM(), {'exception': ValueError}), + 'desc_inputs': [Tensor(np.ones((1, 3, 8, 8)), mstype.float32), + Tensor(np.ones((1, 3, 8, 8)), mstype.float32)]}), ]