|
|
|
@ -58,6 +58,7 @@ class ImageGradients(Cell):
|
|
|
|
|
super(ImageGradients, self).__init__()
|
|
|
|
|
|
|
|
|
|
def construct(self, images):
|
|
|
|
|
_check_input_4d(F.shape(images), "images", self.cls_name)
|
|
|
|
|
batch_size, depth, height, width = P.Shape()(images)
|
|
|
|
|
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
|
|
|
|
|
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
|
|
|
|
@ -151,8 +152,8 @@ class SSIM(Cell):
|
|
|
|
|
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
|
|
|
|
|
|
|
|
|
|
def construct(self, img1, img2):
|
|
|
|
|
_check_input_4d(F.shape(img1), "img1", "SSIM")
|
|
|
|
|
_check_input_4d(F.shape(img2), "img2", "SSIM")
|
|
|
|
|
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
|
|
|
|
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
|
|
|
|
P.SameTypeShape()(img1, img2)
|
|
|
|
|
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
|
|
|
|
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
|
|
|
@ -244,8 +245,8 @@ class PSNR(Cell):
|
|
|
|
|
self.max_val = max_val
|
|
|
|
|
|
|
|
|
|
def construct(self, img1, img2):
|
|
|
|
|
_check_input_4d(F.shape(img1), "img1", "PSNR")
|
|
|
|
|
_check_input_4d(F.shape(img2), "img2", "PSNR")
|
|
|
|
|
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
|
|
|
|
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
|
|
|
|
P.SameTypeShape()(img1, img2)
|
|
|
|
|
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
|
|
|
|
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
|
|
|
|