From bbbd98072810c99353c9c65990ee28f1943a9eb0 Mon Sep 17 00:00:00 2001 From: lihongkang <[lihongkang1@huawei.com]> Date: Wed, 3 Jun 2020 15:13:11 +0800 Subject: [PATCH] add vm for histogramfixedwidth and dataformatdimmap --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 1 + mindspore/ops/_op_impl/tbe/__init__.py | 2 + .../ops/_op_impl/tbe/data_format_dim_map.py | 38 ++++++++++++++++ .../ops/_op_impl/tbe/histogram_fixed_width.py | 40 +++++++++++++++++ mindspore/ops/operations/__init__.py | 8 ++-- mindspore/ops/operations/math_ops.py | 44 +++++++++++++++++++ mindspore/ops/operations/nn_ops.py | 41 ++++++++++++++++- tests/ut/python/ops/test_ops.py | 10 +++++ 8 files changed, 180 insertions(+), 4 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/data_format_dim_map.py create mode 100644 mindspore/ops/_op_impl/tbe/histogram_fixed_width.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index e819261c64..1e3f55790a 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -111,6 +111,7 @@ static std::map tbe_func_adapter_map = { {"reduce_prod", "reduce_prod_d"}, {"a_cos", "acos"}, {"a_cos_grad", "acos_grad"}, + {"histogram_fixed_width", "histogram_fixed_width_d"}, {"broadcast_to", "broadcast_to_d"}}; void TbeAdapter::NormalizeFuncName(std::string *func_name) { diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 484827ec18..11470b265d 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -249,3 +249,5 @@ from .fused_mul_add_n_l2loss import _fused_mul_add_n_l2loss_tbe from .fused_mul_apply_momentum_extern import _fused_mul_apply_momentum_extern_tbe from .lamb_next_right import _lamb_next_right_tbe from .sparse_gather_v2 import _sparse_gather_v2_tbe +from .data_format_dim_map import _data_format_dim_map_tbe +from .histogram_fixed_width import _histogram_fixed_width_tbe diff --git a/mindspore/ops/_op_impl/tbe/data_format_dim_map.py b/mindspore/ops/_op_impl/tbe/data_format_dim_map.py new file mode 100644 index 0000000000..0bbccd30b1 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/data_format_dim_map.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""DataFormatDimMap op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +data_format_dim_map_op_info = TBERegOp("DataFormatDimMap") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("data_format_dim_map.so") \ + .compute_cost(10) \ + .kernel_name("data_format_dim_map") \ + .partial_flag(True) \ + .attr("dst_format", "optional", "str", "all") \ + .attr("src_format", "optional", "str", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(data_format_dim_map_op_info) +def _data_format_dim_map_tbe(): + """DataFormatDimMap TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/histogram_fixed_width.py b/mindspore/ops/_op_impl/tbe/histogram_fixed_width.py new file mode 100644 index 0000000000..32195f1f3c --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/histogram_fixed_width.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""HistogramFixedWidth op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +histogram_fixed_width_op_info = TBERegOp("HistogramFixedWidth") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("histogram_fixed_width_d.so") \ + .compute_cost(10) \ + .kernel_name("histogram_fixed_width_d") \ + .partial_flag(True) \ + .attr("nbins", "required", "int", "all") \ + .attr("dtype", "optional", "str", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "range", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(histogram_fixed_width_op_info) +def _histogram_fixed_width_tbe(): + """HistogramFixedWidth TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 40f8e4dc95..2031210acd 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -49,7 +49,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2 Minimum, Mul, Neg, NMSWithMask, NotEqual, NPUAllocFloatStatus, NPUClearFloatStatus, NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, - Reciprocal, CumSum, + Reciprocal, CumSum, HistogramFixedWidth, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh) @@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl Gelu, Elu, GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, LogSoftmax, - MaxPool, + MaxPool, DataFormatDimMap, AvgPool, Conv2DBackpropInput, ConfusionMulGrad, MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, ResizeBilinear, Sigmoid, @@ -207,6 +207,7 @@ __all__ = [ 'ScatterNd', 'ScatterMax', 'ResizeNearestNeighbor', + 'HistogramFixedWidth', 'Pad', 'MirrorPad', 'GatherNd', @@ -298,7 +299,8 @@ __all__ = [ "BasicLSTMCell", "ConfusionMatrix", "BroadcastTo", - "Range" + "Range", + "DataFormatDimMap" ] __all__.extend(_quant_ops.__all__) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 5b5871a17b..5d5e6fbcc3 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1043,6 +1043,50 @@ class Expm1(PrimitiveWithInfer): return x_type +class HistogramFixedWidth(PrimitiveWithInfer): + """ + Returns a rank 1 histogram counting the number of entries in values that fall into every bin. The bins are equal + width and determined by the arguments range and nbins. + + Args: + dtype (string): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32". + nbins (Tensor): Number of histogram bins, the type is int32. + + Inputs: + - **x** (Tensor) - Numeric Tensor. Must be one of the following types: int32, float32, float16. + - **range** (Tensor) - Must have the same type as x. Shape [2] Tensor of same dtype as x. + x <= range[0] will be mapped to hist[0], x >= range[1] will be mapped to hist[-1]. + + Outputs: + Tensor, the type is int32. + + Examples: + >>> x = Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mindspore.float16) + >>> range = Tensor([0.0, 5.0], mindspore.float16) + >>> hist = P.HistogramFixedWidth(5) + >>> hist(x, range) + [2 1 1 0 2] + """ + + @prim_attr_register + def __init__(self, nbins, dtype='int32'): + self.nbins = validator.check_value_type("nbins", nbins, [int], self.name) + valid_values = ['int32', 'int64'] + self.dtype = validator.check_string("dtype", dtype, valid_values, self.name) + self.init_prim_io_names(inputs=['x', 'range'], outputs=['y']) + + def infer_shape(self, x_shape, range_shape): + return (self.nbins,) + + def infer_dtype(self, x_dtype, range_dtype): + validator.check_subclass("x", x_dtype, mstype.tensor, self.name) + valid_types = (mstype.float16, mstype.float32, mstype.int32) + validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name) + y_dtype = mstype.int32 + return y_dtype + + class Log(PrimitiveWithInfer): """ Returns the natural logarithm of a tensor element-wise. diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index c4da91c180..a9da07d6c4 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1613,6 +1613,45 @@ class L2Loss(PrimitiveWithInfer): return x_type +class DataFormatDimMap(PrimitiveWithInfer): + """ + Returns the dimension index in the destination data format given the one in the source data format. + + Args: + src_format (string): An optional value for source data format. Default: 'NHWC'. + dst_format (string): An optional value for destination data format. Default: 'NCHW'. + + Inputs: + - **input_x** (Tensor) - A Tensor with each element as a dimension index in source data format. + Must be in the range [-4, 4). It's type is int32. + + Outputs: + Tensor, has the same type as the `input_x`. + + Examples: + >>> x = Tensor([0, 1, 2, 3], mindspore.int32) + >>> dfdm = P.DataFormatDimMap() + >>> dfdm(x) + [0 3 1 2] + """ + + @prim_attr_register + def __init__(self, src_format='NHWC', dst_format='NCHW'): + valid_values = ['NHWC', 'NCHW'] + self.src_format = validator.check_string("src_format", src_format, valid_values, self.name) + self.dst_format = validator.check_string("dst_format", dst_format, valid_values, self.name) + self.init_prim_io_names(inputs=['input_x'], outputs=['output']) + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_type): + validator.check_subclass("x", x_type, mstype.tensor, self.name) + valid_types = [mstype.int32] + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + return x_type + + class SGD(PrimitiveWithInfer): """ Computes stochastic gradient descent (optionally with momentum). @@ -3735,7 +3774,7 @@ class BasicLSTMCell(PrimitiveWithInfer): validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name) validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name) - validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4*h_shape[1], Rel.EQ, self.name) + validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) ct_shape = c_shape ht_shape = h_shape it_shape = h_shape diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 7b3a48e25a..fada7df8be 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -764,6 +764,11 @@ test_case_math_ops = [ 'desc_inputs': [Tensor(np.array([[24, 4, 13, 9], [1, 5, 10, 8]]).astype(np.int16))], 'desc_bprop': [], 'skip': ['backward']}), + ('HistogramFixedWidth', { + 'block': P.HistogramFixedWidth(5), + 'desc_inputs': [Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mstype.float16), Tensor([0.0, 5.0], mstype.float16)], + 'desc_bprop': [], + 'skip': ['backward']}), ] test_case_nn_ops = [ @@ -1203,6 +1208,11 @@ test_case_nn_ops = [ Tensor([[0.5, 0.4], [0.6, 0.1]], mstype.float32), Tensor([1, 1], mstype.int32)], 'desc_bprop': [Tensor([[0.7, 0.2], [0.1, 0.07]], mstype.float32)], 'skip': ['backward']}), + ('DataFormatDimMap', { + 'block': P.DataFormatDimMap(), + 'desc_inputs': [Tensor([0, 1, 2, 3], mstype.int32)], + 'desc_bprop': [], + 'skip': ['backward']}), ] test_case_array_ops = [