diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 43e3a903ef..738ae9c3c2 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -53,7 +53,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A NPUAllocFloatStatus, NPUClearFloatStatus, NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, - Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR, + Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot) from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, @@ -98,7 +98,6 @@ __all__ = [ 'EditDistance', 'CropAndResize', 'TensorAdd', - 'IFMR', 'Argmax', 'Argmin', 'ArgMaxWithValue', diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 932e10f1bd..b960bb3605 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -43,7 +43,8 @@ __all__ = ["MinMaxUpdatePerLayer", "BatchNormFoldGradD", "BatchNormFold2_D", "BatchNormFold2GradD", - "BatchNormFold2GradReduce" + "BatchNormFold2GradReduce", + "IFMR" ] @@ -1384,3 +1385,66 @@ class WtsARQ(PrimitiveWithInfer): validator.check_tensor_type_same({"w_min": w_min_dtype}, valid_types, self.name) validator.check_tensor_type_same({"w_max": w_max_dtype}, valid_types, self.name) return w_dtype + + +class IFMR(PrimitiveWithInfer): + """ + The TFMR(Input Feature Map Reconstruction). + + Args: + min_percentile (float): Min init percentile. Default: 0.999999. + max_percentile (float): Max init percentile. Default: 0.999999. + search_range Union[list(float), tuple(float)]: Range of searching. Default: [0.7, 1.3]. + search_step (float): Step size of searching. Default: 0.01. + with_offset (bool): Whether using offset. Default: True. + + Inputs: + - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type. + - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`. + With float16 or float32 data type. + - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`. + With float16 or float32 data type. + - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type. + + Outputs: + - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32. + - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32. + + Examples: + >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32)) + >>> data_min = Tensor([0.1], mstype.float32) + >>> data_max = Tensor([0.5], mstype.float32) + >>> cumsum = Tensor(np.random.rand(4).astype(np.int32)) + >>> ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), + >>> search_step=1.0, with_offset=False) + >>> output = ifmr(data, data_min, data_max, cumsum) + ([7.87401572e-03], [0.00000000e+00]) + """ + + @prim_attr_register + def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01, + with_offset=True): + validator.check_value_type("min_percentile", min_percentile, [float], self.name) + validator.check_value_type("max_percentile", max_percentile, [float], self.name) + validator.check_value_type("search_range", search_range, [list, tuple], self.name) + for item in search_range: + validator.check_positive_float(item, "item of search_range", self.name) + validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name) + validator.check_value_type("search_step", search_step, [float], self.name) + validator.check_value_type("offset_flag", with_offset, [bool], self.name) + + def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape): + validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name) + validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name) + validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name) + validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name) + validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name) + return (1,), (1,) + + def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("input_value", "input_min", "input_max"), + (data_dtype, data_min_dtype, data_max_dtype))) + validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name) + return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 6d08e7576f..0d751f76cd 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -16,7 +16,6 @@ """Operators for math.""" import copy -from functools import partial import numpy as np from ... import context @@ -3680,66 +3679,3 @@ class Eps(PrimitiveWithInfer): 'dtype': input_x['dtype'], } return out - - -class IFMR(PrimitiveWithInfer): - """ - The TFMR(Input Feature Map Reconstruction). - - Args: - min_percentile (float): Min init percentile. Default: 0.999999. - max_percentile (float): Max init percentile. Default: 0.999999. - search_range Union[list(float), tuple(float)]: Range of searching. Default: [0.7, 1.3]. - search_step (float): Step size of searching. Default: 0.01. - with_offset (bool): Whether using offset. Default: True. - - Inputs: - - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type. - - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`. - With float16 or float32 data type. - - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`. - With float16 or float32 data type. - - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type. - - Outputs: - - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32. - - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32. - - Examples: - >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32)) - >>> data_min = Tensor([0.1], mstype.float32) - >>> data_max = Tensor([0.5], mstype.float32) - >>> cumsum = Tensor(np.random.rand(4).astype(np.int32)) - >>> ifmr = P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), - search_step=1.0, with_offset=False) - >>> output = ifmr(data, data_min, data_max, cumsum) - ([7.87401572e-03], [0.00000000e+00]) - """ - - @prim_attr_register - def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01, - with_offset=True): - validator.check_value_type("min_percentile", min_percentile, [float], self.name) - validator.check_value_type("max_percentile", max_percentile, [float], self.name) - validator.check_value_type("search_range", search_range, [list, tuple], self.name) - for item in search_range: - validator.check_positive_float(item, "item of search_range", self.name) - validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name) - validator.check_value_type("search_step", search_step, [float], self.name) - validator.check_value_type("offset_flag", with_offset, [bool], self.name) - - def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape): - validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name) - validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name) - validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name) - validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name) - validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name) - return (1,), (1,) - - def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): - tuple(map(partial(validator.check_tensor_dtype_valid, - valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), - ("input_value", "input_min", "input_max"), - (data_dtype, data_min_dtype, data_max_dtype))) - validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name) - return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 5a6c95db71..6d72f593dd 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -601,10 +601,10 @@ class FusedBatchNorm(Primitive): Inputs: - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`. - - **scale** (Tensor) - Tensor of shape :math:`(C,)`. - - **bias** (Tensor) - Tensor of shape :math:`(C,)`. - - **mean** (Tensor) - Tensor of shape :math:`(C,)`. - - **variance** (Tensor) - Tensor of shape :math:`(C,)`. + - **scale** (Parameter) - Tensor of shape :math:`(C,)`. + - **bias** (Parameter) - Tensor of shape :math:`(C,)`. + - **mean** (Parameter) - Tensor of shape :math:`(C,)`. + - **variance** (Parameter) - Tensor of shape :math:`(C,)`. Outputs: Tuple of 5 Tensor, the normalized input and the updated parameters. @@ -616,13 +616,30 @@ class FusedBatchNorm(Primitive): - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`. Examples: + >>> import mindspore + >>> import mindspore.nn as nn + >>> import numpy as np + >>> from mindspore import Parameter + >>> from mindspore import Tensor + >>> from mindspore.ops import operations as P + >>> class FusedBatchNormNet(nn.Cell): + >>> def __init__(self): + >>> super(FusedBatchNormNet, self).__init__() + >>> self.fused_batch_norm = P.FusedBatchNorm() + >>> self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale") + >>> self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias") + >>> self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean") + >>> self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance") + >>> + >>> def construct(self, input_x): + >>> out = self.fused_batch_norm(input_x, self.scale, self.bias, self.mean, self.variance) + >>> return out + >>> >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) - >>> scale = Tensor(np.ones([64]), mindspore.float32) - >>> bias = Tensor(np.ones([64]), mindspore.float32) - >>> mean = Tensor(np.ones([64]), mindspore.float32) - >>> variance = Tensor(np.ones([64]), mindspore.float32) - >>> op = P.FusedBatchNorm() - >>> output = op(input_x, scale, bias, mean, variance) + >>> net = FusedBatchNormNet() + >>> output = net(input_x) + >>> output[0].shape + (128, 64, 32, 64) """ __mindspore_signature__ = ( sig.make_sig('input_x', dtype=sig.sig_dtype.T2), @@ -673,12 +690,12 @@ class FusedBatchNormEx(PrimitiveWithInfer): Inputs: - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`, data type: float16 or float32. - - **scale** (Tensor) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`, + - **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`, data type: float32. - - **bias** (Tensor) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`, + - **bias** (Parameter) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`, data type: float32. - - **mean** (Tensor) - mean value, Tensor of shape :math:`(C,)`, data type: float32. - - **variance** (Tensor) - variance value, Tensor of shape :math:`(C,)`, data type: float32. + - **mean** (Parameter) - mean value, Tensor of shape :math:`(C,)`, data type: float32. + - **variance** (Parameter) - variance value, Tensor of shape :math:`(C,)`, data type: float32. Outputs: Tuple of 6 Tensors, the normalized input, the updated parameters and reserve. @@ -692,13 +709,30 @@ class FusedBatchNormEx(PrimitiveWithInfer): - **reserve** (Tensor) - reserve space, Tensor of shape :math:`(C,)`, data type: float32. Examples: + >>> import mindspore + >>> import mindspore.nn as nn + >>> import numpy as np + >>> from mindspore import Parameter + >>> from mindspore import Tensor + >>> from mindspore.ops import operations as P + >>> class FusedBatchNormExNet(nn.Cell): + >>> def __init__(self): + >>> super(FusedBatchNormExNet, self).__init__() + >>> self.fused_batch_norm_ex = P.FusedBatchNormEx() + >>> self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale") + >>> self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias") + >>> self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean") + >>> self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance") + >>> + >>> def construct(self, input_x): + >>> out = self.fused_batch_norm_ex(input_x, self.scale, self.bias, self.mean, self.variance) + >>> return out + >>> >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) - >>> scale = Tensor(np.ones([64]), mindspore.float32) - >>> bias = Tensor(np.ones([64]), mindspore.float32) - >>> mean = Tensor(np.ones([64]), mindspore.float32) - >>> variance = Tensor(np.ones([64]), mindspore.float32) - >>> op = P.FusedBatchNormEx() - >>> output = op(input_x, scale, bias, mean, variance) + >>> net = FusedBatchNormExNet() + >>> output = net(input_x) + >>> output[0].shape + (128, 64, 32, 64) """ __mindspore_signature__ = ( sig.make_sig('input_x', dtype=sig.sig_dtype.T2), @@ -756,7 +790,7 @@ class BNTrainingReduce(PrimitiveWithInfer): Examples: >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) - >>> bn_training_reduce = P.BNTrainingReduce(input_x) + >>> bn_training_reduce = P.BNTrainingReduce() >>> output = bn_training_reduce(input_x) """ @@ -5662,13 +5696,30 @@ class DynamicRNN(PrimitiveWithInfer): Has the same type with input `b`. Examples: + >>> import mindspore + >>> import mindspore.nn as nn + >>> import numpy as np + >>> from mindspore import Parameter + >>> from mindspore import Tensor + >>> from mindspore.ops import operations as P + >>> import mindspore.context as context + >>> context.set_context(mode=context.GRAPH_MODE) + >>> class DynamicRNNNet(nn.Cell): + >>> def __init__(self): + >>> super(DynamicRNNNet, self).__init__() + >>> self.dynamic_rnn = P.DynamicRNN() + >>> + >>> def construct(self, x, w, b, init_h, init_c): + >>> out = self.dynamic_rnn(x, w, b, None, init_h, init_c) + >>> return out + >>> >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16)) >>> w = Tensor(np.random.rand(96, 128).astype(np.float16)) >>> b = Tensor(np.random.rand(128).astype(np.float16)) >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) - >>> dynamic_rnn = P.DynamicRNN() - >>> output = dynamic_rnn(x, w, b, None, init_h, init_c) + >>> net = DynamicRNNNet() + >>> output = net(x, w, b, init_h, init_c) >>> output[0].shape (2, 16, 32) """ diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 0279bb19bd..5706405cee 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1446,7 +1446,7 @@ test_case_math_ops = [ 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], 'desc_bprop': [[2, 3, 4, 5]]}), ('IFMR', { - 'block': P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), + 'block': Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), search_step=1.0, with_offset=False), 'desc_inputs': [[3, 4, 5], Tensor([0.1], mstype.float32), Tensor([0.9], mstype.float32), Tensor(np.random.rand(4).astype(np.int32))],