|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
"""basic"""
|
|
|
|
|
import numpy as np
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
from mindspore.common.seed import get_seed
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
from mindspore.common.initializer import initializer
|
|
|
|
|
from mindspore._checkparam import check_int_positive, check_bool
|
|
|
|
@ -60,8 +61,6 @@ class Dropout(Cell):
|
|
|
|
|
Args:
|
|
|
|
|
keep_prob (float): The keep rate, greater than 0 and less equal than 1. E.g. rate=0.9,
|
|
|
|
|
dropping out 10% of input units. Default: 0.5.
|
|
|
|
|
seed0 (int): The first random seed. Default: 0.
|
|
|
|
|
seed1 (int): The second random seed. Default: 0.
|
|
|
|
|
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
@ -79,17 +78,18 @@ class Dropout(Cell):
|
|
|
|
|
>>> net(x)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, keep_prob=0.5, seed0=0, seed1=0, dtype=mstype.float32):
|
|
|
|
|
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
|
|
|
|
|
super(Dropout, self).__init__()
|
|
|
|
|
if keep_prob <= 0 or keep_prob > 1:
|
|
|
|
|
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
|
|
|
|
|
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
|
self.seed0 = seed0
|
|
|
|
|
self.seed1 = seed1
|
|
|
|
|
seed0 = get_seed()
|
|
|
|
|
self.seed0 = seed0 if seed0 is not None else 0
|
|
|
|
|
self.seed1 = 0
|
|
|
|
|
self.dtype = dtype
|
|
|
|
|
self.get_shape = P.Shape()
|
|
|
|
|
self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)
|
|
|
|
|
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
|
|
|
|
|
self.dropout_do_mask = P.DropoutDoMask()
|
|
|
|
|
self.cast = P.Cast()
|
|
|
|
|
self.is_gpu = context.get_context('device_target') in ["GPU"]
|
|
|
|
@ -113,8 +113,7 @@ class Dropout(Cell):
|
|
|
|
|
return self.dropout_do_mask(x, output, keep_prob)
|
|
|
|
|
|
|
|
|
|
def extend_repr(self):
|
|
|
|
|
str_info = 'keep_prob={}, Seed0={}, Seed1={}, dtype={}' \
|
|
|
|
|
.format(self.keep_prob, self.seed0, self.seed1, self.dtype)
|
|
|
|
|
str_info = 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
|
|
|
|
|
return str_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|