update example of some operations.

pull/8850/head
wangshuide2020 4 years ago
parent 0b3aa904c0
commit f64a201804

@ -291,11 +291,12 @@ class MSSSIM(Cell):
Examples:
>>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
>>> img1 = Tensor(np.random.random((1,3,128,128)))
>>> img2 = Tensor(np.random.random((1,3,128,128)))
>>> np.random.seed(0)
>>> img1 = Tensor(np.random.random((1, 3, 128, 128)))
>>> img2 = Tensor(np.random.random((1, 3, 128, 128)))
>>> output = net(img1, img2)
>>> print(output)
[0.22965115]
[0.20607519]
"""
def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
filter_sigma=1.5, k1=0.01, k2=0.03):

@ -285,11 +285,12 @@ class BatchNorm1d(_BatchNorm):
Examples:
>>> net = nn.BatchNorm1d(num_features=4)
>>> np.random.seed(0)
>>> input = Tensor(np.random.randint(0, 255, [2, 4]), mindspore.float32)
>>> output = net(input)
>>> print(output)
[[210.99895 136.99931 89.99955 240.9988 ]
[ 87.99956 157.9992 89.99955 42.999786]]
[[171.99915 46.999763 116.99941 191.99904 ]
[ 66.999664 250.99875 194.99902 102.99948 ]]
"""
def __init__(self,
@ -370,15 +371,18 @@ class BatchNorm2d(_BatchNorm):
Examples:
>>> net = nn.BatchNorm2d(num_features=3)
>>> np.random.seed(0)
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mindspore.float32)
>>> output = net(input)
>>> print(output)
[[[[128.99936 53.99973]
[191.99904 183.99908]]
[[146.99927 182.99908]
[184.99907 120.9994 ]]
[[ 33.99983 234.99883]
[188.99905 11.99994]]]]
[[[[171.99915 46.999763 ]
[116.99941 191.99904 ]]
[[ 66.999664 250.99875 ]
[194.99902 102.99948 ]]
[[ 8.999955 210.99895 ]
[ 20.999895 241.9988 ]]]]
"""
def __init__(self,
@ -455,9 +459,34 @@ class GlobalBatchNorm(_BatchNorm):
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=4)
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> global_bn_op(input)
>>> # This example should be run with multiple processes. Refer to the run_distribute_train.sh
>>> import os
>>> import numpy as np
>>> from mindspore.communication import init
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import nn, Tensor
>>> from mindspore.common import dtype as mstype
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
>>> device_id=int(device_id))
>>> init()
>>> context.reset_auto_parallel_context()
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
>>> np.random.seed(0)
>>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=2)
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mstype.float32)
>>> output = global_bn_op(input)
>>> print(output)
[[[[171.99915 46.999763]
[116.99941 191.99904 ]]
[[ 66.999664 250.99875 ]
[194.99902 102.99948 ]]
[[ 8.999955 210.99895 ]
[ 20.9999895 241.9988 ]]]]
"""
def __init__(self,

@ -248,6 +248,16 @@ class GetNextSingleOp(Cell):
queue_name (str): Queue name to fetch the data.
For detailed information, refer to `ops.operations.GetNext`.
Examples:
>>> # Refer to dataset_helper.py for detail usage.
>>> data_set = get_dataset()
>>> dataset_shapes = data_set.output_shapes()
>>> np_types = data_set.output_types()
>>> dataset_types = convert_type(dataset_shapes, np_types)
>>> queue_name = data_set.__TRANSFER_DATASET__.queue_name
>>> getnext_op = GetNextSingleOp(dataset_types, dataset_shapes, queue_name)
>>> getnext_op()
"""
def __init__(self, dataset_types, dataset_shapes, queue_name):

@ -246,16 +246,19 @@ class DistributedGradReducer(Cell):
ValueError: If degree is not a int or less than 0.
Examples:
>>> from mindspore.communication import init, get_group_size
>>> # This example should be run with multiple processes. Refer to the run_distribute_train.sh
>>> import os
>>> import numpy as np
>>> from mindspore.communication import init
>>> from mindspore.ops import composite as C
>>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import Parameter, Tensor
>>> from mindspore import nn
>>> from mindspore import ParameterTuple
>>> from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
>>> _get_parallel_mode)
>>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
>>> from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean)
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
@ -295,12 +298,28 @@ class DistributedGradReducer(Cell):
>>> grads = self.grad_reducer(grads)
>>> return F.depend(loss, self.optimizer(grads))
>>>
>>> network = Net()
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> train_cell = TrainingWrapper(network, optimizer)
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> class Net(nn.Cell):
>>> def __init__(self, in_features, out_features):
>>> super(Net, self).__init__()
>>> self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
>>> name='weight')
>>> self.matmul = P.MatMul()
>>>
>>> def construct(self, x):
>>> output = self.matmul(x, self.weight)
>>> return output
>>>
>>> size, in_features, out_features = 16, 16, 10
>>> network = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> net_with_loss = WithLossCell(network, loss)
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> train_cell = TrainingWrapper(net_with_loss, optimizer)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> grads = train_cell(inputs, label)
>>> print(grads)
256.0
"""
def __init__(self, parameters, mean=True, degree=None):

@ -76,16 +76,30 @@ class DynamicLossScaleUpdateCell(Cell):
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
>>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore.ops import operations as P
>>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
>>>
>>> class Net(nn.Cell):
>>> def __init__(self, in_features, out_features):
>>> super(Net, self).__init__()
>>> self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
>>> name='weight')
>>> self.matmul = P.MatMul()
>>>
>>> def construct(self, x):
>>> output = self.matmul(x, self.weight)
>>> return output
>>>
>>> in_features, out_features = 16, 10
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> train_network.set_train()
>>>
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scale_sense=scaling_sens)
"""
def __init__(self,
@ -142,16 +156,30 @@ class FixedLossScaleUpdateCell(Cell):
loss_scale_value (float): Initializes loss scale.
Examples:
>>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore.ops import operations as P
>>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
>>>
>>> class Net(nn.Cell):
>>> def __init__(self, in_features, out_features):
>>> super(Net, self).__init__()
>>> self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
>>> name='weight')
>>> self.matmul = P.MatMul()
>>>
>>> def construct(self, x):
>>> output = self.matmul(x, self.weight)
>>> return output
>>>
>>> in_features, out_features = 16, 10
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> train_network.set_train()
>>>
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scale_sense=scaling_sens)
"""
def __init__(self, loss_scale_value):
@ -193,21 +221,45 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
- **loss scaling value** (Tensor) - Tensor with shape :math:`()`
Examples:
>>> #1) when the type scale_sense is Cell:
>>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore.ops import operations as P
>>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
>>> from mindspore.common import dtype as mstype
>>>
>>> class Net(nn.Cell):
>>> def __init__(self, in_features, out_features):
>>> super(Net, self).__init__()
>>> self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
>>> name='weight')
>>> self.matmul = P.MatMul()
>>>
>>> def construct(self, x):
>>> output = self.matmul(x, self.weight)
>>> return output
>>>
>>> size, in_features, out_features = 16, 16, 10
>>> #1) when the type of scale_sense is Cell:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> train_network.set_train()
>>>
>>> #2) when the type scale_sense is Tensor:
>>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> #2) when the type of scale_sense is Tensor:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)
>>> print(output[0])
256.0
"""
def __init__(self, network, optimizer, scale_sense):
super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)

Loading…
Cancel
Save