|
|
|
@ -141,21 +141,37 @@ class Dropout(Cell):
|
|
|
|
|
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
|
|
|
|
|
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
|
|
|
|
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
|
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
|
seed0, seed1 = _get_graph_seed(0, "dropout")
|
|
|
|
|
self.seed0 = seed0
|
|
|
|
|
self.seed1 = seed1
|
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
|
self.dropout = P.Dropout(keep_prob, seed0, seed1)
|
|
|
|
|
self.dtype = dtype
|
|
|
|
|
self.get_shape = P.Shape()
|
|
|
|
|
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
|
|
|
|
|
self.dropout_do_mask = P.DropoutDoMask()
|
|
|
|
|
self.cast = P.Cast()
|
|
|
|
|
self.is_ascend = context.get_context('device_target') in ["Ascend"]
|
|
|
|
|
self.dropout = P.Dropout(keep_prob)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
if not self.training:
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
if not self.is_ascend:
|
|
|
|
|
out, _ = self.dropout(x)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
if self.keep_prob == 1:
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
out, _ = self.dropout(x)
|
|
|
|
|
return out
|
|
|
|
|
shape = self.get_shape(x)
|
|
|
|
|
dtype = P.DType()(x)
|
|
|
|
|
if _is_float_dtype(dtype):
|
|
|
|
|
keep_prob = self.cast(self.keep_prob, dtype)
|
|
|
|
|
else:
|
|
|
|
|
keep_prob = self.cast(self.keep_prob, mstype.float16)
|
|
|
|
|
output = self.dropout_gen_mask(shape, keep_prob)
|
|
|
|
|
return self.dropout_do_mask(x, output, keep_prob)
|
|
|
|
|
|
|
|
|
|
def extend_repr(self):
|
|
|
|
|
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
|
|
|
|
|