diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index aff62042f1..ba95566558 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -15,6 +15,7 @@
 """basic"""
 import numpy as np
 import mindspore.common.dtype as mstype
+from mindspore.common.seed import get_seed
 from mindspore.common.tensor import Tensor
 from mindspore.common.initializer import initializer
 from mindspore._checkparam import check_int_positive, check_bool
@@ -60,8 +61,6 @@ class Dropout(Cell):
     Args:
         keep_prob (float): The keep rate, greater than 0 and less equal than 1. E.g. rate=0.9,
             dropping out 10% of input units. Default: 0.5.
-        seed0 (int): The first random seed. Default: 0.
-        seed1 (int): The second random seed. Default: 0.
         dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.

     Raises:
@@ -83,18 +82,19 @@ class Dropout(Cell):
           [1.0, 1.0, 1.0]]]
     """

-    def __init__(self, keep_prob=0.5, seed0=0, seed1=0, dtype=mstype.float32):
+    def __init__(self, keep_prob=0.5, dtype=mstype.float32):
         super(Dropout, self).__init__()
         if keep_prob <= 0 or keep_prob > 1:
             raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
         validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
         validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
         self.keep_prob = keep_prob
-        self.seed0 = seed0
-        self.seed1 = seed1
+        seed0 = get_seed()
+        self.seed0 = seed0 if seed0 is not None else 0
+        self.seed1 = 0
         self.dtype = dtype
         self.get_shape = P.Shape()
-        self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)
+        self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
         self.dropout_do_mask = P.DropoutDoMask()
         self.cast = P.Cast()
         self.is_gpu = context.get_context('device_target') in ["GPU"]
@@ -117,8 +117,7 @@ class Dropout(Cell):
         return self.dropout_do_mask(x, output, keep_prob)

     def extend_repr(self):
-        str_info = 'keep_prob={}, Seed0={}, Seed1={}, dtype={}' \
-            .format(self.keep_prob, self.seed0, self.seed1, self.dtype)
+        str_info = 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
         return str_info


diff --git a/tests/ut/python/nn/test_dropout.py b/tests/ut/python/nn/test_dropout.py
index 93cb9c81ed..0ac85144eb 100644
--- a/tests/ut/python/nn/test_dropout.py
+++ b/tests/ut/python/nn/test_dropout.py
@@ -26,7 +26,7 @@ context.set_context(device_target="Ascend")
 def test_check_dropout_3():
     Tensor(np.ones([20, 16, 50]).astype(np.int32))
     with pytest.raises(ValueError):
-        nn.Dropout(3, 0, 1)
+        nn.Dropout(3)


 class Net_dropout(nn.Cell):
diff --git a/tests/ut/python/pynative_mode/nn/test_dropout.py b/tests/ut/python/pynative_mode/nn/test_dropout.py
index 3272e92a51..6865baa018 100644
--- a/tests/ut/python/pynative_mode/nn/test_dropout.py
+++ b/tests/ut/python/pynative_mode/nn/test_dropout.py
@@ -23,24 +23,12 @@ from mindspore import dtype as mstype
 context.set_context(device_target="Ascend")


-def test_check_dropout_1():
+def test_check_dropout():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.8)
     m(x)


-def test_check_dropout_2():
-    x = Tensor(np.ones([20, 16, 50]), mstype.float32)
-    m = nn.Dropout(0.3, seed0=1)
-    m(x)
-
-
-def test_check_dropout_3():
-    x = Tensor(np.ones([20, 16, 50]), mstype.float32)
-    m = nn.Dropout(0.3, seed0=1, seed1=1)
-    m(x)
-
-
 class Net_Dropout(nn.Cell):
     def __init__(self):
         super(Net_Dropout, self).__init__()
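For reference, a minimal usage sketch of the reworked API (not part of the patch): Dropout no longer accepts seed0/seed1 and instead reads the global seed through get_seed() at construction time. The set_seed import below is an assumption about the public counterpart to get_seed in mindspore.common.seed; the exact name and import path may differ by version.

# Usage sketch, illustrative only and not part of the patch.
# Assumes a public set_seed() counterpart to get_seed(); the import path may vary.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.common.seed import set_seed   # assumed counterpart to get_seed

context.set_context(device_target="Ascend")

set_seed(1)                       # global seed; Dropout picks it up via get_seed()
net = nn.Dropout(keep_prob=0.8)   # seed0/seed1 arguments are gone
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
out = net(x)                      # DropoutGenMask is now seeded from the global seed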