|
|
|
@ -29,7 +29,6 @@ from mindspore.ops.primitive import constexpr, Primitive
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore._extends import cell_attr_register
|
|
|
|
|
from mindspore._checkparam import Rel, Validator
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from ..cell import Cell
|
|
|
|
|
from .activation import get_activation
|
|
|
|
|
|
|
|
|
@ -146,33 +145,17 @@ class Dropout(Cell):
|
|
|
|
|
seed0, seed1 = _get_graph_seed(0, "dropout")
|
|
|
|
|
self.seed0 = seed0
|
|
|
|
|
self.seed1 = seed1
|
|
|
|
|
self.dtype = dtype
|
|
|
|
|
self.get_shape = P.Shape()
|
|
|
|
|
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
|
|
|
|
|
self.dropout_do_mask = P.DropoutDoMask()
|
|
|
|
|
self.cast = P.Cast()
|
|
|
|
|
self.is_ascend = context.get_context('device_target') in ["Ascend"]
|
|
|
|
|
self.dropout = P.Dropout(keep_prob)
|
|
|
|
|
self.dropout = P.Dropout(keep_prob, seed0, seed1)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
if not self.training:
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
if not self.is_ascend:
|
|
|
|
|
out, _ = self.dropout(x)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
if self.keep_prob == 1:
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
shape = self.get_shape(x)
|
|
|
|
|
dtype = P.DType()(x)
|
|
|
|
|
if _is_float_dtype(dtype):
|
|
|
|
|
keep_prob = self.cast(self.keep_prob, dtype)
|
|
|
|
|
else:
|
|
|
|
|
keep_prob = self.cast(self.keep_prob, mstype.float16)
|
|
|
|
|
output = self.dropout_gen_mask(shape, keep_prob)
|
|
|
|
|
return self.dropout_do_mask(x, output, keep_prob)
|
|
|
|
|
out, _ = self.dropout(x)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def extend_repr(self):
|
|
|
|
|
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
|
|
|
|
|