!5904 fix image gradients height 1 bug

Merge pull request !5904 from zhaozhenlong/cloud/issue/image_gradient_h_1
pull/5904/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit bfcff2273c

@ -66,13 +66,19 @@ class ImageGradients(Cell):
check = _check_input_4d(F.shape(images), "images", self.cls_name) check = _check_input_4d(F.shape(images), "images", self.cls_name)
images = F.depend(images, check) images = F.depend(images, check)
batch_size, depth, height, width = P.Shape()(images) batch_size, depth, height, width = P.Shape()(images)
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] if height == 1:
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
dy = P.Concat(2)((dy, dy_last)) else:
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
dx = images[:, :, :, 1:] - images[:, :, :, :width - 1] dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0) dy = P.Concat(2)((dy, dy_last))
dx = P.Concat(3)((dx, dx_last))
if width == 1:
dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
else:
dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
dx = P.Concat(3)((dx, dx_last))
return dy, dx return dy, dx

Loading…
Cancel
Save