modify nn annotation

pull/14102/head
changzherui 4 years ago
parent a76ea6f828
commit 6c51aa3fdd

@ -259,8 +259,13 @@ class CellList(_CellListBase, Cell):
>>> cell_ls.append(relu)
>>> cell_ls
CellList<
(0): Conv2d<input_channels=100, ..., bias_init=None>
(1): BatchNorm2d<num_features=20, ..., moving_variance=Parameter (name=variance)>
(0): Conv2d<input_channels=100, output_channels=20, kernel_size=(3, 3),stride=(1, 1), pad_mode=same,
padding=0, dilation=(1, 1), group=1, has_bias=Falseweight_init=normal, bias_init=zeros, format=NCHW>
(1): BatchNorm2d<num_features=20, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=1.gamma,
shape=(20,), dtype=Float32, requires_grad=True), beta=Parameter (name=1.beta, shape=(20,), dtype=Float32,
requires_grad=True), moving_mean=Parameter (name=1.moving_mean, shape=(20,), dtype=Float32,
requires_grad=False), moving_variance=Parameter (name=1.moving_variance, shape=(20,), dtype=Float32,
requires_grad=False)>
(2): ReLU<>
>
"""

@ -512,7 +512,7 @@ class MultiClassDiceLoss(_Loss):
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
>>> output = loss(y_pred, y)
>>> print(output)
[0.3283009]
0.3283009
"""
def __init__(self, weights=None, ignore_indiex=None, activation="softmax"):
super(MultiClassDiceLoss, self).__init__()

Loading…
Cancel
Save