!1729 fix bprop of ExtractImagePatches

Merge pull request !1729 from zhangbuxue/fix_bprop_of_ExtractImagePatches
pull/1729/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d8e06d0af5

@ -482,7 +482,7 @@ class Unfold(Cell):
Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
data type is int8, float16, uint8.
data type is number.
Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',

@ -14,13 +14,12 @@
# ============================================================================
"""Define the grad rules of neural network related operations."""
from mindspore.common import dtype as mstype
from .grad_base import bprop_getters
from .. import functional as F
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import bprop_getters
from ... import context
@ -73,7 +72,7 @@ def get_bprop_extract_image_patches(self):
slice_op = P.Slice()
transpose = P.Transpose()
matmul = P.MatMul()
cast = P.Cast()
_, ksizes_row, ksizes_col, _ = self.ksizes
def bprop(x, out, dout):
@ -82,16 +81,13 @@ def get_bprop_extract_image_patches(self):
x_indices_num = x_row * x_col + 1
x_idx = F.tuple_to_array(range(1, x_indices_num))
x_idx = reshape(x_idx, (1, x_row, x_col, 1))
x_idx = cast(x_idx, mstype.float16)
x_idx_patch = extract_image_patches(x_idx)
x_idx_patch = transpose(x_idx_patch, (0, 3, 1, 2))
x_idx_patch = cast(x_idx_patch, mstype.int32)
out_shape = get_shape(out)
_, out_row, out_col, _ = out_shape
out_indices_num = out_row * out_col * ksizes_row * ksizes_col
out_idx = F.tuple_to_array(range(out_indices_num))
out_idx = reshape(out_idx, (1, ksizes_row * ksizes_col, out_row, out_col))
out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
idx_tensor = reshape(idx_tensor, (-1, 2))

@ -41,7 +41,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
data type is int8, float16, uint8.
data type is number.
Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',
@ -94,5 +94,5 @@ class ExtractImagePatches(PrimitiveWithInfer):
def infer_dtype(self, input_x):
"""infer dtype"""
validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
return input_x

Loading…
Cancel
Save