delete seed0 and seed1 in Dropout

pull/5735/head
caozhou 4 years ago
parent d76ac7c6e8
commit 6509fa9023

@ -15,6 +15,7 @@
"""basic"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.seed import get_seed
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore._checkparam import check_int_positive, check_bool
@ -60,8 +61,6 @@ class Dropout(Cell):
Args:
keep_prob (float): The keep rate, greater than 0 and less equal than 1. E.g. rate=0.9,
dropping out 10% of input units. Default: 0.5.
seed0 (int): The first random seed. Default: 0.
seed1 (int): The second random seed. Default: 0.
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
Raises:
@ -79,17 +78,18 @@ class Dropout(Cell):
>>> net(x)
"""
def __init__(self, keep_prob=0.5, seed0=0, seed1=0, dtype=mstype.float32):
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.keep_prob = keep_prob
self.seed0 = seed0
self.seed1 = seed1
seed0 = get_seed()
self.seed0 = seed0 if seed0 is not None else 0
self.seed1 = 0
self.dtype = dtype
self.get_shape = P.Shape()
self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
self.dropout_do_mask = P.DropoutDoMask()
self.cast = P.Cast()
self.is_gpu = context.get_context('device_target') in ["GPU"]
@ -113,8 +113,7 @@ class Dropout(Cell):
return self.dropout_do_mask(x, output, keep_prob)
def extend_repr(self):
str_info = 'keep_prob={}, Seed0={}, Seed1={}, dtype={}' \
.format(self.keep_prob, self.seed0, self.seed1, self.dtype)
str_info = 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
return str_info

@ -26,7 +26,7 @@ context.set_context(device_target="Ascend")
def test_check_dropout_3():
Tensor(np.ones([20, 16, 50]).astype(np.int32))
with pytest.raises(ValueError):
nn.Dropout(3, 0, 1)
nn.Dropout(3)
class Net_dropout(nn.Cell):

@ -23,24 +23,12 @@ from mindspore import dtype as mstype
context.set_context(device_target="Ascend")
def test_check_dropout_1():
def test_check_dropout():
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
m = nn.Dropout(0.8)
m(x)
def test_check_dropout_2():
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
m = nn.Dropout(0.3, seed0=1)
m(x)
def test_check_dropout_3():
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
m = nn.Dropout(0.3, seed0=1, seed1=1)
m(x)
class Net_Dropout(nn.Cell):
def __init__(self):
super(Net_Dropout, self).__init__()

Loading…
Cancel
Save