|
|
|
@ -276,7 +276,7 @@ class GetNextSingleOp(Cell):
|
|
|
|
|
>>> relu = P.ReLU()
|
|
|
|
|
>>> result = relu(data).asnumpy()
|
|
|
|
|
>>> print(result.shape)
|
|
|
|
|
>>> (32, 1, 32, 32)
|
|
|
|
|
(32, 1, 32, 32)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset_types, dataset_shapes, queue_name):
|
|
|
|
@ -356,6 +356,7 @@ class WithEvalCell(Cell):
|
|
|
|
|
Args:
|
|
|
|
|
network (Cell): The network Cell.
|
|
|
|
|
loss_fn (Cell): The loss Cell.
|
|
|
|
|
add_cast_fp32 (bool): Adjust the data type to float32.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
|
|
|
|
@ -410,7 +411,7 @@ class ParameterUpdate(Cell):
|
|
|
|
|
>>> param = network.parameters_dict()['weight']
|
|
|
|
|
>>> update = nn.ParameterUpdate(param)
|
|
|
|
|
>>> update.phase = "update_param"
|
|
|
|
|
>>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32)
|
|
|
|
|
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
|
|
|
|
|
>>> network_updata = update(weight)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|