From 0f89cc1da44428ee63afa9886c4c3db8d75e952c Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 27 Apr 2020 14:42:30 +0800 Subject: [PATCH] dock AcoshGrad for GE and AvgPool AvgPoolGrad for Vm --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 1 + mindspore/ccsrc/operator/ops.cc | 1 + mindspore/ccsrc/operator/ops.h | 1 + .../pass/const_input_to_attr_registry.cc | 1 + mindspore/ccsrc/transform/convert.cc | 2 + mindspore/ccsrc/transform/op_declare.cc | 5 + mindspore/ccsrc/transform/op_declare.h | 4 +- mindspore/nn/layer/basic.py | 3 +- mindspore/ops/_grad/grad_nn_ops.py | 45 +++++++-- mindspore/ops/_op_impl/tbe/__init__.py | 2 + mindspore/ops/_op_impl/tbe/avg_pool.py | 39 ++++++++ mindspore/ops/_op_impl/tbe/avg_pool_grad.py | 42 ++++++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/_grad_ops.py | 17 ++++ mindspore/ops/operations/_inner_ops.py | 98 +++++++++++++++++++ mindspore/ops/operations/nn_ops.py | 76 -------------- tests/ut/python/ops/test_ops.py | 4 +- 17 files changed, 256 insertions(+), 88 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/avg_pool.py create mode 100644 mindspore/ops/_op_impl/tbe/avg_pool_grad.py create mode 100644 mindspore/ops/operations/_inner_ops.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 44750fab4f..d5be2cbd29 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -38,6 +38,7 @@ static std::map tbe_func_adapter_map = { {"reduce_mean", "reduce_mean_d"}, {"reduce_max", "reduce_max_d"}, {"reduce_min", "reduce_min_d"}, + {"avg_pool_grad", "avg_pool_grad_d"}, {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, {"conv2d_backprop_input", "conv2d_backprop_input_d"}, {"depthwise_conv2d_native", "depthwise_conv2d"}, diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index 407efe5689..6510ef79ea 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -170,6 +170,7 @@ const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); const PrimitivePtr kPrimPoolingGrad = std::make_shared("PoolingGrad"); const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); +const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D"); const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared("FusedBatchNormGrad"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index e938e5c64e..b37d068d94 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -178,6 +178,7 @@ extern const PrimitivePtr kPrimFusedBatchNorm; extern const PrimitivePtr kPrimConv2D; extern const PrimitivePtr kPrimMaxPool; extern const PrimitivePtr kPrimMaxPoolGrad; +extern const PrimitivePtr kPrimAvgPoolGrad; extern const PrimitivePtr kPrimFusedBatchNormGrad; extern const PrimitivePtr kPrimReluGrad; extern const PrimitivePtr kPrimConv2DBackpropInput; diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc index 0b4263685b..3153a3bef9 100644 --- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc @@ -25,6 +25,7 @@ namespace mindspore { namespace opt { 
ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimCast->name(), {1}); + Register(prim::kPrimAvgPoolGrad->name(), {0}); Register(prim::kPrimConv2DBackpropInput->name(), {2}); Register(prim::kPrimConv2DBackpropFilter->name(), {2}); Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 177f939f37..e7ea44b555 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -178,6 +178,7 @@ const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy"; const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad"; const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad"; const char kNameAcosh[] = "Acosh"; +const char kNameAcoshGrad[] = "AcoshGrad"; const char kNameFloorMod[] = "FloorMod"; const char kNameSpaceToDepth[] = "SpaceToDepth"; const char kNameDepthToSpace[] = "DepthToSpace"; @@ -375,6 +376,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, {string(kNameAcosh), ADPT_DESC(Acosh)}, + {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, {string(kNameFloorMod), ADPT_DESC(FloorMod)}, {string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)}, {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)}, diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 477c915b15..b1195cfb1c 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -357,6 +357,11 @@ INPUT_MAP(Acosh) = {{1, INPUT_DESC(x)}}; ATTR_MAP(Acosh) = EMPTY_ATTR_MAP; OUTPUT_MAP(Acosh) = {{0, OUTPUT_DESC(y)}}; +// AcoshGrad +INPUT_MAP(AcoshGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}}; + // Floor INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}}; ATTR_MAP(Floor) = EMPTY_ATTR_MAP; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index 3be3546455..a2dc16c285 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -327,13 +327,15 @@ DECLARE_OP_ADAPTER(Const) DECLARE_OP_USE_OUTPUT(Const) DECLARE_OP_ADAPTER(Cos) DECLARE_OP_USE_OUTPUT(Cos) + DECLARE_OP_ADAPTER(Acos) DECLARE_OP_USE_OUTPUT(Acos) - DECLARE_OP_ADAPTER(AcosGrad) DECLARE_OP_USE_OUTPUT(AcosGrad) DECLARE_OP_ADAPTER(Acosh) DECLARE_OP_USE_OUTPUT(Acosh) +DECLARE_OP_ADAPTER(AcoshGrad) +DECLARE_OP_USE_OUTPUT(AcoshGrad) DECLARE_OP_ADAPTER(Floor) DECLARE_OP_USE_OUTPUT(Floor) diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 2f8b38e818..9c8de85a68 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -21,6 +21,7 @@ from mindspore._checkparam import check_int_positive, check_bool from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops.functional import identity +from mindspore.ops.operations import _inner_ops as inner from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register from mindspore.common.api import ms_function @@ -480,7 +481,7 @@ class Unfold(Cell): """ def __init__(self, ksizes, strides, rates, padding="valid"): super(Unfold, self).__init__() - self.extract_image_patches = P.ExtractImagePatches(ksizes, strides, rates, padding) + self.extract_image_patches = 
inner.ExtractImagePatches(ksizes, strides, rates, padding) self.transpose = P.Transpose() self.format_NHWC = (0, 2, 3, 1) self.format_NCHW = (0, 3, 1, 2) diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index baccdbbbb2..fc94544176 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -18,6 +18,7 @@ from mindspore.common import dtype as mstype from .. import functional as F from .. import operations as P from ..operations import _grad_ops as G +from ..operations import _inner_ops as inner from ..composite.multitype_ops.zeros_like_impl import zeros_like from .grad_base import bprop_getters @@ -29,6 +30,7 @@ def get_bprop_bias_add(self): def bprop(x, w, out, dout): return dout, bias_grad(dout) + return bprop @@ -49,18 +51,19 @@ def get_bprop_conv2d(self): dx = input_grad(dout, w, get_shape(x)) dw = filter_grad(dout, x, get_shape(w)) return dx, dw + return bprop -@bprop_getters.register(P.ExtractImagePatches) +@bprop_getters.register(inner.ExtractImagePatches) def get_bprop_extract_image_patches(self): """Grad definition for `ExtractImagePatches` operation.""" get_shape = P.Shape() reshape = P.Reshape() - extract_image_patches = P.ExtractImagePatches(ksizes=self.ksizes, - strides=self.strides, - rates=self.rates, - padding=self.padding) + extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes, + strides=self.strides, + rates=self.rates, + padding=self.padding) concat = P.Concat(axis=-1) expand_dims = P.ExpandDims() scatter_nd = P.ScatterNd() @@ -104,6 +107,7 @@ def get_bprop_extract_image_patches(self): dx = transpose(dx, (2, 0, 1, 3)) return (dx,) + return bprop @@ -124,6 +128,7 @@ def get_bprop_depthwise_conv2d_native(self): dx = input_grad(get_shape(x), w, dout) dw = filter_grad(x, get_shape(w), dout) return dx, dw + return bprop @@ -133,11 +138,12 @@ def get_bprop_max_pool_with_argmax(self): maxpool_grad = G.MaxPoolGradWithArgmax( ksize=self.ksize, strides=self.strides, - padding=self.padding,) + padding=self.padding) def bprop(x, out, dout): dx = maxpool_grad(x, dout[0], out[1]) return (dx,) + return bprop @@ -152,6 +158,7 @@ def get_bprop_max_pool_grad(self): def bprop(x, out, dout): dx = maxpool_grad(x, out, dout) return (dx,) + return bprop @@ -192,6 +199,7 @@ def get_bprop_dropout_gen_mask(self): def bprop(shape, keep_prob, out, dout): return (zeros_like(shape), zeros_like(keep_prob)) + return bprop @@ -202,6 +210,7 @@ def get_bprop_dropout_do_mask(self): def bprop(x, y, keep_prob, out, dout): return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob)) + return bprop @@ -213,6 +222,7 @@ def get_bprop_relu(self): def bprop(x, out, dout): dx = input_grad(dout, out) return (dx,) + return bprop @@ -224,6 +234,7 @@ def get_bprop_relu6(self): def bprop(x, out, dout): dx = input_grad(dout, x) return (dx,) + return bprop @@ -236,6 +247,7 @@ def get_bprop_relu_v2(self): mask = out[1] dx = input_grad(dout[0], mask) return (dx,) + return bprop @@ -247,6 +259,7 @@ def get_bprop_hswish(self): def bprop(x, out, dout): dx = input_grad(dout, x) return (dx,) + return bprop @@ -258,6 +271,7 @@ def get_bprop_hsigmoid(self): def bprop(x, out, dout): dx = input_grad(dout, x) return (dx,) + return bprop @@ -269,6 +283,7 @@ def get_bprop_elu(self): def bprop(x, out, dout): dx = input_grad(dout, x) return (dx,) + return bprop @@ -280,6 +295,7 @@ def get_bprop_sigmoid(self): def bprop(x, out, dout): dx = input_grad(out, dout) return (dx,) + return bprop @@ -294,6 +310,7 @@ def get_bprop_softmax(self): def 
bprop(x, out, dout): dx = mul(sub(dout, sum_func(mul(dout, out), axis)), out) return (dx,) + return bprop @@ -305,6 +322,7 @@ def get_bprop_log_softmax(self): def bprop(x, out, dout): dx = logsoftmax_grad(out, dout) return (dx,) + return bprop @@ -316,6 +334,7 @@ def get_bprop_tanh(self): def bprop(x, out, dout): dx = logsoftmax_grad(out, dout) return (dx,) + return bprop @@ -327,6 +346,7 @@ def get_bprop_gelu(self): def bprop(x, out, dout): dx = input_grad(dout, x, out) return (dx,) + return bprop @@ -343,6 +363,7 @@ def get_bprop_fused_batch_norm(self): dscale = out[1] dbias = out[2] return dx, dscale, dbias, zeros_like(mean), zeros_like(variance) + return bprop @@ -366,6 +387,7 @@ def get_bprop_batch_norm(self): dscale = out[1] dbias = out[2] return dx, dscale, dbias, zeros_like(mean), zeros_like(variance) + return bprop @@ -377,6 +399,7 @@ def get_bprop_layer_norm(self): def bprop(x, gamma, beta, out, dout): dx, d_gamma, d_beta = layer_norm_grad(x, dout[0], out[2], out[1], gamma) return dx, d_gamma, d_beta + return bprop @@ -388,6 +411,7 @@ def get_bprop_l2normalize(self): def bprop(x, out, dout): dx = input_grad(x, out, dout) return (dx,) + return bprop @@ -400,6 +424,7 @@ def get_bprop_softmax_cross_entropy_with_logits(self): grad = out[1] grad = grad * expand(dout[0], -1) return grad, zeros_like(labels) + return bprop @@ -417,6 +442,7 @@ def get_bprop_sparse_softmax_cross_entropy_with_logits(self): grad = F.depend(grad, out) grad = grad * dout return grad, zeros_like(labels) + return bprop @@ -428,6 +454,7 @@ def get_bprop_resize_bilinear(self): def bprop(x, out, dout): dx = resize_grad(dout, x) return (dx,) + return bprop @@ -437,6 +464,7 @@ def get_bprop_onehot(self): def bprop(indices, depth, on_value, off_value, out, dout): return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value) + return bprop @@ -453,6 +481,7 @@ def get_bprop_top_kv2(self): updates = dout[0] shapes = shape_op(input_x) return scatter(indices, updates, shapes), zeros_like(k) + return bprop @@ -518,6 +547,7 @@ def get_bprop_lstm(self): dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state) dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state) return dx, dhx, dcx, dw + return bprop @@ -529,6 +559,7 @@ def get_bprop_sigmoid_crossentropy_with_logits(self): def bprop(x, y, out, dout): dx = op(x, y, dout) return (dx, zeros_like(y)) + return bprop @@ -545,6 +576,7 @@ def get_bprop_pad(self): shp = shape_op(x) dx = P.Slice()(dout, begin, shp) return (dx,) + return bprop @@ -556,6 +588,7 @@ def get_bprop_mirror_pad(self): def bprop(x, paddings, out, dout): dx = mirror_pad_grad(dout, paddings, x) return (dx, zeros_like(paddings)) + return bprop diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index f9240ee325..ce1e02e915 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -151,3 +151,5 @@ from .greater_equal import _greater_equal_tbe from .not_equal import _not_equal_tbe from .floor_mod import _floor_mod_tbe from .scatter_nd_update import _scatter_nd_update_tbe +from .avg_pool import _avg_pool_tbe +from .avg_pool_grad import _avg_pool_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/avg_pool.py b/mindspore/ops/_op_impl/tbe/avg_pool.py new file mode 100644 index 0000000000..5db5947b01 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/avg_pool.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""AvgPool op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +avg_pool_op_info = TBERegOp("AvgPool") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("avg_pool.so") \ + .compute_cost(10) \ + .kernel_name("avg_pool") \ + .partial_flag(True) \ + .attr("ksize", "required", "listInt", "all") \ + .attr("strides", "required", "listInt", "all") \ + .attr("padding", "required", "str", "all") \ + .attr("data_format", "optional", "str", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .get_op_info() + + +@op_info_register(avg_pool_op_info) +def _avg_pool_tbe(): + """AvgPool TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/avg_pool_grad.py b/mindspore/ops/_op_impl/tbe/avg_pool_grad.py new file mode 100644 index 0000000000..693636edcd --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/avg_pool_grad.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""AvgPoolGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +avg_pool_grad_op_info = TBERegOp("AvgPoolGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("avg_pool_grad_d.so") \ + .compute_cost(10) \ + .kernel_name("avg_pool_grad_d") \ + .partial_flag(True) \ + .attr("x_origin", "required", "listInt", "all") \ + .attr("ksize", "required", "listInt", "all") \ + .attr("strides", "required", "listInt", "all") \ + .attr("padding", "required", "str", "all") \ + .attr("data_format", "optional", "str", "all") \ + .input(0, "input_grad", False, "required", "all") \ + .input(1, "mean_matrix", False, "optional", "all") \ + .input(2, "kernel_matrix", False, "optional", "all") \ + .output(0, "out_grad", True, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_C1HWNCoC0, DataType.F16_5HD) \ + .get_op_info() + + +@op_info_register(avg_pool_grad_op_info) +def _avg_pool_grad_tbe(): + """AvgPoolGrad TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 868d3b359e..d83f5accd0 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -57,7 +57,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, Gelu, Elu, GetNext, L2Normalize, LayerNorm, L2Loss, LogSoftmax, - MaxPool, ExtractImagePatches, + MaxPool, AvgPool, Conv2DBackpropInput, ConfusionMulGrad, MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, ResizeBilinear, Sigmoid, @@ -89,7 +89,6 @@ __all__ = [ 'Sqrt', 'Square', 'Conv2D', - 'ExtractImagePatches', 'Flatten', 'MaxPoolWithArgmax', 'FusedBatchNorm', diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index e130dcc382..747caa7a96 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -59,6 +59,23 @@ class ACosGrad(PrimitiveWithInfer): return x +class AcoshGrad(PrimitiveWithInfer): + """Performs grad of Acosh operation.""" + + @prim_attr_register + def __init__(self): + """init AcoshGrad""" + + def infer_shape(self, x, dout): + validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) + return x + + def infer_dtype(self, x, dout): + args = {"x": x, "dout": dout} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + return x + + class BatchNormGrad(PrimitiveWithInfer): """Performs grad of BatchNorm operation.""" diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py new file mode 100644 index 0000000000..632f9c0a20 --- /dev/null +++ b/mindspore/ops/operations/_inner_ops.py @@ -0,0 +1,98 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Inner operators.""" + +from ..._checkparam import Validator as validator +from ...common import dtype as mstype +from ..primitive import PrimitiveWithInfer, prim_attr_register + + +class ExtractImagePatches(PrimitiveWithInfer): + """ + Extract patches from images. + The input tensor must be a 4-D tensor and the data format is NHWC. + + Args: + ksizes (Union[tuple[int], list[int]]): The size of the sliding window, should be a tuple or list of int, + and the format is [1, ksize_row, ksize_col, 1]. + strides (Union[tuple[int], list[int]]): Distance between the centers of two consecutive patches, + should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1]. + rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between corresponding pixel + positions along each dimension, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1]. + padding (str): The type of padding algorithm, either "same" or "valid", + case insensitive. Default: "valid". + + - same: Means that the patch can extend beyond the original image, and the part outside is filled with 0. + + - valid: Means that the patch area taken must be completely contained in the original image. + + Inputs: + - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and + data type is int8, float16 or float32. + + Outputs: + Tensor, a 4-D tensor whose data type is the same as 'input_x', + and the shape is [out_batch, out_row, out_col, out_depth], where out_batch is the same as in_batch. + """ + + @prim_attr_register + def __init__(self, ksizes, strides, rates, padding="valid"): + """init""" + def _check_tuple_or_list(arg_name, arg_val, prim_name): + validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], self.name) + if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: + raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, " + f"{arg_name}_col, 1], but got {arg_val}.") + if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1: + raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be " + f"positive integers, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col " + f"is {arg_val[2]}.") + + _check_tuple_or_list("ksize", ksizes, self.name) + _check_tuple_or_list("stride", strides, self.name) + _check_tuple_or_list("rate", rates, self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) + self.add_prim_attr("padding", self.padding) + + def infer_shape(self, input_x): + """infer shape""" + if len(input_x) != 4: + raise ValueError("The `input_x` should be a 4-D tensor, " + f"but got a {len(input_x)}-D tensor whose shape is {input_x}") + in_batch, in_row, in_col, in_depth = input_x + _, ksize_row, ksize_col, _ = self.ksizes + _, stride_row, stride_col, _ = self.strides + _, rate_row, rate_col, _ = self.rates + + out_batch = in_batch + out_depth = ksize_row * ksize_col * in_depth + + if self.padding == "VALID": + out_row = \ + (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1 + out_col = \ + (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1 + else: + out_row = (in_row - 1) // stride_row + 1 + out_col = (in_col - 1) // stride_col + 1 + + out_shape = [out_batch, out_row, out_col, out_depth] + return out_shape + + def
infer_dtype(self, input_x): + """infer dtype""" + validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) + return input_x diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 66656b559e..c355707242 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2654,82 +2654,6 @@ class ApplyFtrl(PrimitiveWithInfer): return var_type -class ExtractImagePatches(PrimitiveWithInfer): - """ - Extract patches from images. - The input tensor must be a 4-D tensor and the data format is NHWC. - - Args: - ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int, - and the format is [1, ksize_row, ksize_col, 1]. - strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches, - should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1]. - rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim - pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1]. - padding (str): The type of padding algorithm, is a string whose value is "same" or "valid", - not case sensitive. Default: "valid". - - - same: Means that the patch can take the part beyond the original image, and this part is filled with 0. - - - valid: Means that the patch area taken must be completely contained in the original image. - - Inputs: - - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and - data type is int8, float16, uint8. - - Outputs: - Tensor, a 4-D tensor whose data type is same as 'input_x', - and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch. 
- """ - - @prim_attr_register - def __init__(self, ksizes, strides, rates, padding="valid"): - """init""" - def _check_tuple_or_list(arg_name, arg_val, prim_name): - validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) - if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: - raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, " - f"{arg_name}_col, 1], but got {arg_val}.") - if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1: - raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an " - f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col " - f"is {arg_val[2]}") - - _check_tuple_or_list("ksize", ksizes, self.name) - _check_tuple_or_list("stride", strides, self.name) - _check_tuple_or_list("rate", rates, self.name) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) - self.add_prim_attr("padding", self.padding) - - def infer_shape(self, input_x): - in_batch, in_row, in_col, in_depth = input_x - _, ksize_row, ksize_col, _ = self.ksizes - _, stride_row, stride_col, _ = self.strides - _, rate_row, rate_col, _ = self.rates - if len(input_x) != 4: - raise ValueError("The `input_x` should be a 4-D tensor, " - f"but got a {len(input_x)}-D tensor whose shape is {input_x}") - - out_batch = in_batch - out_depth = ksize_row * ksize_col * in_depth - - if self.padding == "VALID": - out_row = \ - (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1 - out_col = \ - (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1 - else: - out_row = (in_row - 1) // stride_row + 1 - out_col = (in_col - 1) // stride_col + 1 - - out_shape = [out_batch, out_row, out_col, out_depth] - return out_shape - - def infer_dtype(self, input_x): - validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) - return input_x - - class ConfusionMulGrad(PrimitiveWithInfer): """ `output0` is the result of which input0 dot multily input1. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 3b8c50533b..68ff816fb3 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -265,8 +265,8 @@ test_case_math_ops = [ 'desc_bprop': [[2, 3]]}), ('Acosh', { 'block': P.Acosh(), - 'desc_inputs': [Tensor(np.random.rand(4).astype(np.float16))], - 'skip': ['backward']}), + 'desc_inputs': [[3, 4, 5]], + 'desc_bprop': [[3, 4, 5]]}), ('Sin', { 'block': P.Sin(), 'desc_inputs': [[2, 3]],