!8655 add example outputs for some operations

From: @wangshuide2020
Reviewed-by: 
Signed-off-by:
pull/8655/MERGE
commit bcc6e1ca28 (mindspore-ci-bot, committed via Gitee 4 years ago)

@@ -147,8 +147,9 @@ class ELU(Cell):
 Examples:
 >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
 >>> elu = nn.ELU()
->>> elu(input_x)
+>>> result = elu(input_x)
+>>> print(result)
 [-0.63212055 -0.86466473 0 2 1]
 """
 def __init__(self, alpha=1.0):
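
For reference, the printed values follow directly from the ELU definition, elu(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise, with the default alpha = 1.0. A minimal NumPy sketch of the same computation (NumPy stands in for the MindSpore kernel here):

import numpy as np

x = np.array([-1, -2, 0, 2, 1], np.float32)
alpha = 1.0
elu = np.where(x > 0, x, alpha * np.expm1(x))  # expm1(x) == exp(x) - 1
# elu ~= [-0.63212055, -0.86466473, 0.0, 2.0, 1.0]
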
@@ -446,8 +447,9 @@ class HSwish(Cell):
 Examples:
 >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
 >>> hswish = nn.HSwish()
->>> hswish(input_x)
+>>> result = hswish(input_x)
+>>> print(result)
 [-0.3333 -0.3333 0 1.666 0.6665]
 """
 def __init__(self):
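
hswish(x) = x * relu6(x + 3) / 6, so for example hswish(2) = 2 * 5 / 6 ~ 1.666 as printed; the same check applies to the P.HSwish example further down. A NumPy sketch:

import numpy as np

x = np.array([-1, -2, 0, 2, 1], np.float16)
hswish = x * np.clip(x + 3, 0, 6) / 6  # relu6(y) == clip(y, 0, 6)
# hswish ~= [-0.3333, -0.3333, 0.0, 1.666, 0.6665] in float16
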
@@ -480,8 +482,9 @@ class HSigmoid(Cell):
 Examples:
 >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
 >>> hsigmoid = nn.HSigmoid()
->>> hsigmoid(input_x)
+>>> result = hsigmoid(input_x)
+>>> print(result)
 [0.3333 0.1666 0.5 0.833 0.6665]
 """
 def __init__(self):
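hsigmoid(x) = relu6(x + 3) / 6, e.g. hsigmoid(-2) = 1 / 6 ~ 0.1666 as printed; this also covers the P.HSigmoid example below. A NumPy sketch:

import numpy as np

x = np.array([-1, -2, 0, 2, 1], np.float16)
hsigmoid = np.clip(x + 3, 0, 6) / 6
# hsigmoid ~= [0.3333, 0.1666, 0.5, 0.833, 0.6665] in float16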

@@ -162,9 +162,13 @@ class EmbeddingLookup(Cell):
 Examples:
 >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
->>> out = nn.EmbeddingLookup(4,2)(input_indices)
->>> output.shape
-(2, 2, 2)
+>>> result = nn.EmbeddingLookup(4,2)(input_indices)
+>>> print(result)
+[[[ 0.00856617  0.01039034]
+  [ 0.00196276 -0.00094072]]
+
+ [[ 0.01279703  0.00078912]
+  [ 0.00084863 -0.00742412]]]
 """
 BATCH_SLICE = "batch_slice"
 FIELD_SLICE = "field_slice"
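Note the printed embedding values come from the layer's randomly initialized weight table and will differ from run to run; only the (2, 2, 2) shape is stable. The lookup itself is a plain row gather; a NumPy sketch with a hypothetical 4 x 2 table:

import numpy as np

table = np.random.normal(0, 0.01, (4, 2)).astype(np.float32)  # hypothetical weights
indices = np.array([[1, 0], [3, 2]])
result = table[indices]  # gather rows of the table; result.shape == (2, 2, 2)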

@@ -290,9 +290,11 @@ class MSSSIM(Cell):
 Examples:
 >>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
->>> img1 = Tensor(np.random.random((1,3,128,128)))
->>> img2 = Tensor(np.random.random((1,3,128,128)))
->>> msssim = net(img1, img2)
+>>> img1 = Tensor(np.random.random((1, 3, 128, 128)))
+>>> img2 = Tensor(np.random.random((1, 3, 128, 128)))
+>>> result = net(img1, img2)
+>>> print(result)
+[0.20930639]
 """
 def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
              filter_sigma=1.5, k1=0.01, k2=0.03):
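Because img1 and img2 are drawn from np.random.random without a seed, the printed [0.20930639] is illustrative rather than reproducible; fixing the RNG first would pin the value down, e.g.:

import numpy as np

np.random.seed(0)  # makes the random images, and hence the MS-SSIM value, repeatable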

@@ -284,9 +284,13 @@ class BatchNorm1d(_BatchNorm):
 Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.
 Examples:
->>> net = nn.BatchNorm1d(num_features=16)
->>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
->>> net(input)
+>>> net = nn.BatchNorm1d(num_features=4)
+>>> input = Tensor(np.random.randint(0, 255, [3, 4]), mindspore.float32)
+>>> result = net(input)
+>>> print(result)
+[[ 57.99971  50.99974 220.99889 222.99889]
+ [106.99947 193.99902  77.99961 101.99949]
+ [ 85.99957 188.99905  46.99976 226.99887]]
 """
 def __init__(self,
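
The outputs here (and in the BatchNorm2d example below) are nearly identical to the random integer inputs: a freshly constructed batch-norm layer starts from moving_mean = 0, moving_variance = 1, gamma = 1 and beta = 0, so each entry is effectively x / sqrt(1 + eps) with the default eps = 1e-5. A quick check of the first entry, assuming those defaults and that the corresponding random input was 58:

import numpy as np

print(58.0 / np.sqrt(1.0 + 1e-5))  # ~57.99971, matching the first printed value
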
@@ -367,8 +371,23 @@ class BatchNorm2d(_BatchNorm):
 Examples:
 >>> net = nn.BatchNorm2d(num_features=3)
->>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
->>> net(input)
+>>> input = Tensor(np.random.randint(0, 255, [1, 3, 4, 4]), mindspore.float32)
+>>> result = net(input)
+>>> print(result)
+[[[[148.99925 148.99925 178.9991   77.99961]
+   [ 41.99979  97.99951 157.9992   94.99953]
+   [ 87.99956 158.9992   50.99974 179.9991 ]
+   [146.99927  27.99986 119.9994  253.99873]]
+
+  [[178.9991  187.99905 190.99904  88.99956]
+   [213.99893 158.9992   13.99993 200.999  ]
+   [224.99887  56.99971 246.99876 239.9988 ]
+   [ 97.99951  34.99983  28.99986  57.99971]]
+
+  [[ 14.99993  31.99984 136.99931 207.99896]
+   [180.9991   28.99986  23.99988  71.99964]
+   [112.99944  36.99981 213.99893  71.99964]
+   [  8.99996 162.99919 157.9992   41.99979]]]]
 """
 def __init__(self,

@@ -88,7 +88,9 @@ class ExponentialDecayLR(LearningRateSchedule):
 >>> decay_steps = 4
 >>> global_step = Tensor(2, mstype.int32)
 >>> exponential_decay_lr = ExponentialDecayLR(learning_rate, decay_rate, decay_steps)
->>> exponential_decay_lr(global_step)
+>>> result = exponential_decay_lr(global_step)
+>>> print(result)
+0.09486833
 """
 def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False):
     super(ExponentialDecayLR, self).__init__()
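
The printed 0.09486833 is learning_rate * decay_rate ** (global_step / decay_steps). The lines defining learning_rate and decay_rate are elided above this hunk; assuming they use the usual doc values 0.1 and 0.9:

learning_rate, decay_rate, decay_steps, global_step = 0.1, 0.9, 4, 2  # first two assumed
print(learning_rate * decay_rate ** (global_step / decay_steps))  # 0.09486832980505137
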
@@ -144,7 +146,9 @@ class NaturalExpDecayLR(LearningRateSchedule):
 >>> decay_steps = 4
 >>> global_step = Tensor(2, mstype.int32)
 >>> natural_exp_decay_lr = NaturalExpDecayLR(learning_rate, decay_rate, decay_steps, True)
->>> natural_exp_decay_lr(global_step)
+>>> result = natural_exp_decay_lr(global_step)
+>>> print(result)
+0.016529894
 """
 def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False):
     super(NaturalExpDecayLR, self).__init__()
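
Numerically, 0.016529894 equals 0.1 * e**(-0.9 * 2), i.e. learning_rate * exp(-decay_rate * global_step), under the assumption that the elided lines set learning_rate = 0.1 and decay_rate = 0.9:

import math

learning_rate, decay_rate, global_step = 0.1, 0.9, 2  # first two assumed
print(learning_rate * math.exp(-decay_rate * global_step))  # ~0.0165299 (float32 prints 0.016529894)
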
@@ -201,7 +205,9 @@ class InverseDecayLR(LearningRateSchedule):
 >>> decay_steps = 4
 >>> global_step = Tensor(2, mstype.int32)
 >>> inverse_decay_lr = InverseDecayLR(learning_rate, decay_rate, decay_steps, True)
->>> inverse_decay_lr(global_step)
+>>> result = inverse_decay_lr(global_step)
+>>> print(result)
+0.06896552
 """
 def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False):
     super(InverseDecayLR, self).__init__()
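
Here 0.06896552 matches learning_rate / (1 + decay_rate * global_step / decay_steps) = 0.1 / 1.45, again assuming learning_rate = 0.1 and decay_rate = 0.9 in the elided lines:

learning_rate, decay_rate, decay_steps, global_step = 0.1, 0.9, 4, 2  # first two assumed
print(learning_rate / (1 + decay_rate * global_step / decay_steps))  # 0.06896551724137931
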
@@ -247,7 +253,9 @@ class CosineDecayLR(LearningRateSchedule):
 >>> decay_steps = 4
 >>> global_steps = Tensor(2, mstype.int32)
 >>> cosine_decay_lr = CosineDecayLR(min_lr, max_lr, decay_steps)
->>> cosine_decay_lr(global_steps)
+>>> result = cosine_decay_lr(global_steps)
+>>> print(result)
+0.055
 """
 def __init__(self, min_lr, max_lr, decay_steps):
     super(CosineDecayLR, self).__init__()
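
Cosine decay gives min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * global_step / decay_steps)); at the halfway point the cosine term vanishes, so with the assumed elided values min_lr = 0.01 and max_lr = 0.1 the result is 0.01 + 0.045 = 0.055:

import math

min_lr, max_lr, decay_steps, global_step = 0.01, 0.1, 4, 2  # min_lr and max_lr assumed
print(min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * global_step / decay_steps)))  # ~0.055
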
@@ -313,7 +321,9 @@ class PolynomialDecayLR(LearningRateSchedule):
 >>> power = 0.5
 >>> global_step = Tensor(2, mstype.int32)
 >>> polynomial_decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
->>> polynomial_decay_lr(global_step)
+>>> result = polynomial_decay_lr(global_step)
+>>> print(result)
+0.07363961
 """
 def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False):
     super(PolynomialDecayLR, self).__init__()
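
Polynomial decay computes (learning_rate - end_learning_rate) * (1 - global_step / decay_steps) ** power + end_learning_rate; assuming the elided lines set learning_rate = 0.1, end_learning_rate = 0.01 and decay_steps = 4:

learning_rate, end_learning_rate, decay_steps, power, global_step = 0.1, 0.01, 4, 0.5, 2  # first three assumed
print((learning_rate - end_learning_rate) * (1 - global_step / decay_steps) ** power + end_learning_rate)
# ~0.07363961
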
@@ -379,7 +389,9 @@ class WarmUpLR(LearningRateSchedule):
 >>> warmup_steps = 2
 >>> global_step = Tensor(2, mstype.int32)
 >>> warmup_lr = WarmUpLR(learning_rate, warmup_steps)
->>> warmup_lr(global_step)
+>>> result = warmup_lr(global_step)
+>>> print(result)
+0.1
 """
 def __init__(self, learning_rate, warmup_steps):
     super(WarmUpLR, self).__init__()
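Linear warm-up scales the rate by min(global_step, warmup_steps) / warmup_steps, so at step 2 of a 2-step warm-up the schedule has already reached the full rate; assuming the elided learning_rate = 0.1:

learning_rate, warmup_steps, global_step = 0.1, 2, 2  # learning_rate assumed
print(learning_rate * min(global_step, warmup_steps) / warmup_steps)  # 0.1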

@@ -41,6 +41,8 @@ class Accuracy(EvaluationBase):
 >>> metric.clear()
 >>> metric.update(x, y)
 >>> accuracy = metric.eval()
+>>> print(accuracy)
+0.66666666
 """
 def __init__(self, eval_type='classification'):
     super(Accuracy, self).__init__(eval_type)
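0.66666666 is simply 2 correct predictions out of 3 samples. The x and y defined above this hunk are elided, so the stand-in below is hypothetical; with argmax-decoded logits the metric reduces to a mean over matches:

import numpy as np

y_pred = np.argmax(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), axis=1)  # -> [1, 0, 0]
y_true = np.array([1, 0, 1])
print((y_pred == y_true).mean())  # 0.6666666666666666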

@@ -129,7 +129,9 @@ class F1(Fbeta):
 >>> y = Tensor(np.array([1, 0, 1]))
 >>> metric = nn.F1()
 >>> metric.update(x, y)
->>> f1 = metric.eval()
+>>> result = metric.eval()
+>>> print(result)
+[0.66666667 0.66666667]
 """
 def __init__(self):
     super(F1, self).__init__(1.0)
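nn.F1 returns one score per class, each computed as 2 * TP / (2 * TP + FP + FN) (Fbeta with beta = 1). Since x is elided above, the counts below are hypothetical ones that yield the printed 2/3:

tp, fp, fn = 1, 1, 0  # hypothetical per-class counts
print(2 * tp / (2 * tp + fp + fn))  # 0.6666666666666666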

@@ -457,6 +457,8 @@ class HSwish(PrimitiveWithInfer):
 >>> hswish = P.HSwish()
 >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
 >>> result = hswish(input_x)
+>>> print(result)
+[-0.3333 -0.3333 0 1.666 0.6665]
 """
 @prim_attr_register
@@ -530,6 +532,8 @@ class HSigmoid(PrimitiveWithInfer):
 >>> hsigmoid = P.HSigmoid()
 >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
 >>> result = hsigmoid(input_x)
+>>> print(result)
+[0.3333 0.1666 0.5 0.833 0.6665]
 """
 @prim_attr_register
@@ -2755,6 +2759,8 @@ class Gelu(PrimitiveWithInfer):
 >>> tensor = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
 >>> gelu = P.Gelu()
 >>> result = gelu(tensor)
+>>> print(result)
+[0.841192 1.9545976 2.9963627]
 """
 @prim_attr_register
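These values match the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))), rather than the exact erf form (which would give 0.841345 at x = 1), suggesting the kernel uses the approximation. A NumPy check:

import numpy as np

x = np.array([1.0, 2.0, 3.0], np.float32)
gelu = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))
# gelu ~= [0.841192, 1.9545977, 2.9963627]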
