Add Examples to the mindspore.ops.operations and mindspore.nn packages

pull/1223/head
zhouneng 5 years ago
parent 4ecc9389e0
commit 3cc750fdce

@ -62,6 +62,10 @@ class FakeQuantWithMinMax(Cell):
Outputs:
Tensor, with the same type and shape as `x`.
Examples:
>>> fake_quant = nn.FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""
def __init__(self,
@ -182,6 +186,12 @@ class Conv2dBatchNormQuant(Cell):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> batchnorm_quant = nn.Conv2dBatchNormQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
...                                           dilation=(1, 1))
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 2, 3)), mindspore.float32)
>>> result = batchnorm_quant(input_x)
"""
def __init__(self,
@ -339,6 +349,11 @@ class Conv2dQuant(_Conv):
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> conv2d_quant = nn.Conv2dQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
...                               dilation=(1, 1))
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 2, 3)), mindspore.float32)
>>> result = conv2d_quant(input_x)
"""
def __init__(self,
@ -412,6 +427,11 @@ class DenseQuant(Cell):
Outputs:
Tensor of shape :math:`(N, C_{out})`.
Examples:
>>> dense_quant = nn.DenseQuant(3, 6)
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
>>> result = dense_quant(input_x)
"""
def __init__(
@ -503,6 +523,10 @@ class ReLUQuant(Cell):
Outputs:
Tensor, with the same type and shape as `x`.
Examples:
>>> relu_quant = nn.ReLUQuant()
>>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32)
>>> result = relu_quant(input_x)
"""
def __init__(self,
@ -546,6 +570,10 @@ class ReLU6Quant(Cell):
Outputs:
Tensor, with the same type and shape as `x`.
Examples:
>>> relu6_quant = nn.ReLU6Quant(4, 1)
>>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
>>> result = relu6_quant(input_x)
"""
def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
@ -584,6 +612,10 @@ class HSwishQuant(Cell):
Outputs:
Tensor, with the same type and shape as `x`.
Examples:
>>> hswish_quant = nn.HSwishQuant(4, 1)
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = hswish_quant(input_x)
"""
def __init__(self,
@ -633,6 +665,10 @@ class HSigmoidQuant(Cell):
Outputs:
Tensor, with the same type and shape as `x`.
Examples:
>>> hsigmoid_quant = nn.HSigmoidQuant(4, 1)
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = hsigmoid_quant(input_x)
"""
def __init__(self,
@ -682,6 +718,11 @@ class TensorAddQuant(Cell):
Outputs:
Tensor, with the same type and shape as `x`.
Examples:
>>> add_quant = nn.TensorAddQuant()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
>>> result = add_quant(input_x, input_y)
"""
def __init__(self,

@ -98,7 +98,17 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
"""Performs grad of FakeQuantWithMinMax operation."""
r"""
Performs grad of FakeQuantWithMinMax operation.
Examples:
>>> fake_min_max_grad = P.FakeQuantWithMinMaxGrad()
>>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
>>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
>>> _min = Tensor(np.array([-4]), mindspore.float32)
>>> _max = Tensor(np.array([2]), mindspore.float32)
>>> result = fake_min_max_grad(dout, input_x, _min, _max)
"""
support_quant_bit = [4, 8]
@prim_attr_register
@ -149,10 +159,11 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
- Tensor, has the same type as input.
Examples:
>>> fake_quant = P.FakeQuantWithMinMaxPerChannel()
>>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
>>> _min = Tensor(np.linspace(-2, 2, 3), mindspore.float32)
>>> _max = Tensor(np.linspace(8, 12, 3), mindspore.float32)
>>> result = fake_quant(input_x, _min, _max)
"""
support_quant_bit = [4, 8]
channel_idx = 0
@ -190,7 +201,17 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
"""Performs grad of FakeQuantWithMinMaxPerChannel operation."""
r"""
Performs grad of FakeQuantWithMinMaxPerChannel operation.
Examples:
>>> fqmmpc_grad = P.FakeQuantWithMinMaxPerChannelGrad()
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
>>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
>>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
>>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
>>> result = fqmmpc_grad(dout, input_x, _min, _max)
"""
support_quant_bit = [4, 8]
@prim_attr_register
@ -243,6 +264,13 @@ class BatchNormFold(PrimitiveWithInfer):
- **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
Examples:
>>> batch_norm_fold = P.BatchNormFold()
>>> input_x = Tensor(np.array([1, 2, -1, -2, -2, 1]).reshape(2, 3), mindspore.float32)
>>> mean = Tensor(np.array([0.5, -1, 1]), mindspore.float32)
>>> variance = Tensor(np.array([0.36, 0.4, 0.49]), mindspore.float32)
>>> global_step = Tensor(np.arange(6), mindspore.int32)
>>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
"""
channel = 1
@ -273,7 +301,19 @@ class BatchNormFold(PrimitiveWithInfer):
class BatchNormFoldGrad(PrimitiveWithInfer):
"""Performs grad of BatchNormFold operation."""
r"""
Performs grad of BatchNormFold operation.
Examples:
>>> batch_norm_fold_grad = P.BatchNormFoldGrad()
>>> d_batch_mean = Tensor(np.random.randint(-2, 2, (1, 2, 2, 3)), mindspore.float32)
>>> d_batch_std = Tensor(np.random.randn(1, 2, 2, 3), mindspore.float32)
>>> input_x = Tensor(np.random.randint(0, 256, (4, 1, 4, 6)), mindspore.float32)
>>> batch_mean = Tensor(np.random.randint(-8, 8, (1, 2, 2, 3)), mindspore.float32)
>>> batch_std = Tensor(np.random.randint(0, 12, (1, 2, 2, 3)), mindspore.float32)
>>> global_step = Tensor([2], mindspore.int32)
>>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
"""
channel = 1
@prim_attr_register
@ -321,6 +361,12 @@ class CorrectionMul(PrimitiveWithInfer):
Outputs:
- **out** (Tensor) - Tensor has the same shape as x.
Examples:
>>> correction_mul = P.CorrectionMul()
>>> input_x = Tensor(np.random.randint(-8, 12, (3, 4)), mindspore.float32)
>>> batch_std = Tensor(np.array([1.5, 3, 2]), mindspore.float32)
>>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
>>> out = correction_mul(input_x, batch_std, running_std)
"""
channel = 0
@ -343,7 +389,17 @@ class CorrectionMul(PrimitiveWithInfer):
class CorrectionMulGrad(PrimitiveWithInfer):
"""Performs grad of CorrectionMul operation."""
r"""
Performs grad of CorrectionMul operation.
Examples:
>>> correction_mul_grad = P.CorrectionMulGrad()
>>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
>>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
>>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
>>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
>>> result = correction_mul_grad(dout, input_x, gamma, running_std)
"""
channel = 0
@prim_attr_register
@ -385,6 +441,18 @@ class BatchNormFold2(PrimitiveWithInfer):
Outputs:
- **y** (Tensor) - Tensor has the same shape as x.
Examples:
>>> batch_norm_fold2 = P.BatchNormFold2()
>>> input_x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
>>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
>>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
>>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
>>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
>>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
>>> running_mean = Tensor(np.array([-0.1, 0, -0.1]), mindspore.float32)
>>> global_step = Tensor(np.random.randint(1, 8, (8,)), mindspore.int32)
>>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
...                           running_std, running_mean, global_step)
"""
channel = 1
@ -418,7 +486,21 @@ class BatchNormFold2(PrimitiveWithInfer):
class BatchNormFold2Grad(PrimitiveWithInfer):
"""Performs grad of CorrectionAddGrad operation."""
r"""
Performs grad of CorrectionAddGrad operation.
Examples:
>>> bnf2_grad = P.BatchNormFold2Grad()
>>> input_x = Tensor(np.arange(3*3*12*12).reshape(6, 3, 6, 12), mindspore.float32)
>>> dout = Tensor(np.random.randint(-32, 32, (6, 3, 6, 12)), mindspore.float32)
>>> gamma = Tensor(np.random.randint(-4, 4, (3, 1, 1, 2)), mindspore.float32)
>>> batch_std = Tensor(np.random.randint(0, 8, (3, 1, 1, 2)), mindspore.float32)
>>> batch_mean = Tensor(np.random.randint(-6, 6, (3, 1, 1, 2)), mindspore.float32)
>>> running_std = Tensor(np.linspace(0, 2, 6).reshape(3, 1, 1, 2), mindspore.float32)
>>> running_mean = Tensor(np.random.randint(-3, 3, (3, 1, 1, 2)), mindspore.float32)
>>> global_step = Tensor(np.array([-2]), mindspore.int32)
>>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
"""
channel = 1
@prim_attr_register

@ -1156,6 +1156,16 @@ class Tile(PrimitiveWithInfer):
For example, if the shape of `input_x` is extended to :math:`(1, ..., x_1, x_2, ..., x_S)`,
then the shapes at corresponding positions are multiplied, and
the shape of the output is :math:`(1*y_1, ..., x_S*y_R)`.
Examples:
>>> tile = P.Tile()
>>> input_x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
>>> multiples = (2, 3)
>>> result = tile(input_x, multiples)
>>> print(result)
[[1. 2. 1. 2. 1. 2.]
 [3. 4. 3. 4. 3. 4.]
 [1. 2. 1. 2. 1. 2.]
 [3. 4. 3. 4. 3. 4.]]
"""
@prim_attr_register

@ -144,6 +144,12 @@ class Merge(PrimitiveWithInfer):
Outputs:
tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape as the elements of `inputs`.
Examples:
>>> merge = P.Merge()
>>> input_x = Tensor(np.linspace(0, 8, 8).reshape(2, 4), mindspore.float32)
>>> input_y = Tensor(np.random.randint(-4, 4, (2, 4)), mindspore.float32)
>>> result = merge((input_x, input_y))
"""
@prim_attr_register

@ -713,6 +713,12 @@ class Neg(PrimitiveWithInfer):
Outputs:
Tensor, has the same shape and dtype as input.
Examples:
>>> neg = P.Neg()
>>> input_x = Tensor(np.array([1, 2, -1, 2, 0, -3.5]), mindspore.float32)
>>> result = neg(input_x)
>>> print(result)
[-1. -2. 1. -2. 0. 3.5]
"""
@prim_attr_register
@ -1623,6 +1629,7 @@ class LogicalOr(_LogicBinaryOp):
def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
class IsNan(PrimitiveWithInfer):
"""
Judges which elements are nan at each position
@ -1632,6 +1639,11 @@ class IsNan(PrimitiveWithInfer):
Outputs:
Tensor, has the same shape of input, and the dtype is bool.
Examples:
>>> is_nan = P.IsNan()
>>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
>>> result = is_nan(input_x)
"""
@prim_attr_register
@ -1645,6 +1657,7 @@ class IsNan(PrimitiveWithInfer):
def infer_dtype(self, x_dtype):
return mstype.bool_
class IsInf(PrimitiveWithInfer):
"""
Judges which elements are inf or -inf at each position
@ -1654,6 +1667,11 @@ class IsInf(PrimitiveWithInfer):
Outputs:
Tensor, has the same shape of input, and the dtype is bool.
Examples:
>>> is_inf = P.IsInf()
>>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
>>> result = is_inf(input_x)
"""
@prim_attr_register
@ -1667,6 +1685,7 @@ class IsInf(PrimitiveWithInfer):
def infer_dtype(self, x_dtype):
return mstype.bool_
class IsFinite(PrimitiveWithInfer):
"""
Judges which elements are finite at each position
@ -1676,6 +1695,12 @@ class IsFinite(PrimitiveWithInfer):
Outputs:
Tensor, has the same shape of input, and the dtype is bool.
Examples:
>>> is_finite = P.IsFinite()
>>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
>>> result = is_finite(input_x)
>>> print(result)
[False True False]
"""
@prim_attr_register
@ -1691,6 +1716,7 @@ class IsFinite(PrimitiveWithInfer):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
return mstype.bool_
class FloatStatus(PrimitiveWithInfer):
"""
Determines whether the elements contain nan, inf or -inf. `0` for normal, `1` for overflow.
@ -1701,6 +1727,11 @@ class FloatStatus(PrimitiveWithInfer):
Outputs:
Tensor, has the shape of `(1,)`, and the same dtype as the input
(`mindspore.dtype.float32` or `mindspore.dtype.float16`).
Examples:
>>> float_status = P.FloatStatus()
>>> input_x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
>>> result = float_status(input_x)
"""
@prim_attr_register
@ -1714,6 +1745,7 @@ class FloatStatus(PrimitiveWithInfer):
def infer_dtype(self, x_dtype):
return x_dtype
class NPUAllocFloatStatus(PrimitiveWithInfer):
"""
Allocates a flag to store the overflow status.

File diff suppressed because it is too large

@ -168,6 +168,26 @@ class CheckValid(PrimitiveWithInfer):
Outputs:
Tensor, the validation results (one bool per bbox).
Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.check_valid = P.CheckValid()
...     def construct(self, x, y):
...         valid_result = self.check_valid(x, y)
...         return valid_result
...
>>> bboxes = Tensor(np.linspace(0, 6, 12).reshape(3, 4), mindspore.float32)
>>> img_metas = Tensor(np.array([2, 1, 3]), mindspore.float32)
>>> net = Net()
>>> result = net(bboxes, img_metas)
>>> print(result)
[True False False]
"""
@prim_attr_register
