|
|
|
@ -65,18 +65,22 @@ class Dropout(Cell):
|
|
|
|
|
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If `keep_prob` is not in range (0, 1).
|
|
|
|
|
ValueError: If `keep_prob` is not in range (0, 1].
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input** (Tensor) - An N-D Tensor.
|
|
|
|
|
- **input** (Tensor) - The input tensor.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, output tensor with the same shape as the input.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> x = Tensor(np.ones([20, 16, 50]), mindspore.float32)
|
|
|
|
|
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
|
|
|
|
|
>>> net = nn.Dropout(keep_prob=0.8)
|
|
|
|
|
>>> net(x)
|
|
|
|
|
[[[1.0, 1.0, 1.0],
|
|
|
|
|
[1.0, 1.0, 1.0]],
|
|
|
|
|
[[1.0, 1.0, 1.0],
|
|
|
|
|
[1.0, 1.0, 1.0]]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, keep_prob=0.5, seed0=0, seed1=0, dtype=mstype.float32):
|
|
|
|
@ -84,6 +88,7 @@ class Dropout(Cell):
|
|
|
|
|
if keep_prob <= 0 or keep_prob > 1:
|
|
|
|
|
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
|
|
|
|
|
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
|
|
|
|
validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
|
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
|
self.seed0 = seed0
|
|
|
|
|
self.seed1 = seed1
|
|
|
|
@ -107,8 +112,7 @@ class Dropout(Cell):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
shape = self.get_shape(x)
|
|
|
|
|
dtype = P.DType()(x)
|
|
|
|
|
keep_prob = self.cast(self.keep_prob, dtype)
|
|
|
|
|
keep_prob = self.cast(self.keep_prob, mstype.float32)
|
|
|
|
|
output = self.dropout_gen_mask(shape, keep_prob)
|
|
|
|
|
return self.dropout_do_mask(x, output, keep_prob)
|
|
|
|
|
|
|
|
|
|