add dropout primitive

pull/1034/head
chenzomi 5 years ago
parent 3d3b9d5474
commit 661f9dfaf8

@@ -25,6 +25,7 @@ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
+from mindspore import context
 from ..cell import Cell
 from .activation import get_activation
 from ..._checkparam import Validator as validator
@@ -84,8 +85,19 @@ class Dropout(Cell):
         self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)
         self.dropout_do_mask = P.DropoutDoMask()
         self.cast = P.Cast()
+        self.is_gpu = context.get_context('device_target') in ["GPU"]
+        if self.is_gpu:
+            self.dropout = P.Dropout(keep_prob)
 
     def construct(self, x):
+        if not self.training:
+            return x
+
+        if self.is_gpu:
+            out, _ = self.dropout(x)
+            return out
+
         shape = self.get_shape(x)
         dtype = P.DType()(x)
         keep_prob = self.cast(self.keep_prob, dtype)
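
For reference, a minimal usage sketch of the updated layer (not part of the diff; the GPU device target and the 2x4 input are illustrative):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore import dtype as mstype

    context.set_context(device_target="GPU")

    net = nn.Dropout(keep_prob=0.5)
    net.set_train()  # dropout only fires in training mode; construct returns x unchanged otherwise
    x = Tensor(np.ones([2, 4]), mstype.float32)
    out = net(x)  # surviving elements are scaled by 1/keep_prob, the rest are zeroed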

@@ -643,3 +643,17 @@ def get_bprop_binary_cross_entropy(self):
         return dx, zeros_like(y), zeros_like(weight)
 
     return bprop
+
+
+@bprop_getters.register(P.Dropout)
+def get_bprop_dropout(self):
+    """Grad definition for `Dropout` operation."""
+    grad = P.DropoutGrad(self.drop_prob)
+
+    def bprop(x, out, dout):
+        _, mask = out
+        dy, _ = dout
+        dx = grad(dy, mask)
+        return (dx,)
+
+    return bprop
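
Conceptually, the backward pass just re-applies the forward mask to the incoming gradient. A NumPy sketch of the intended semantics (illustrative only, assuming the usual inverted-dropout convention; the actual DropoutGrad kernel runs on device and its exact scaling is backend-defined):

    import numpy as np

    def dropout_grad_reference(dy, mask, keep_prob):
        # Gradient flows only through elements kept in the forward pass;
        # under inverted dropout they are rescaled by 1/keep_prob.
        return dy * mask / keep_prob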

@@ -52,7 +52,7 @@ from .random_ops import (RandomChoiceWithMask)
 from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      BiasAdd, Conv2D,
                      DepthwiseConv2dNative,
-                     DropoutDoMask,
+                     DropoutDoMask, DropoutGrad, Dropout,
                      DropoutGenMask, Flatten, FusedBatchNorm,
                      Gelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss,
@@ -157,6 +157,8 @@ __all__ = [
     'Shape',
     'DropoutDoMask',
     'DropoutGenMask',
+    'DropoutGrad',
+    'Dropout',
     'Neg',
     'Slice',
     'DType',

@@ -2762,3 +2762,68 @@ class ConfusionMulGrad(PrimitiveWithInfer):
         validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
         validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
         return input0_dtype, input1_dtype
+
+
+class Dropout(PrimitiveWithInfer):
+    """
+    During training, randomly zeroes some of the elements of the input tensor
+    with probability `drop_prob`, and returns both the result and the mask.
+
+    Args:
+        drop_prob (float): Probability of an element to be zeroed. Default: 0.
+
+    Inputs:
+        - **input** (Tensor) - The input tensor, with float16 or float32 data type.
+
+    Outputs:
+        - **output** (Tensor) - With the same shape and dtype as the input.
+        - **mask** (Tensor) - The mask applied to the input, with the same shape and dtype.
+
+    Examples:
+        >>> dropout = P.Dropout(drop_prob=0.5)
+        >>> x = Tensor(np.ones([20, 16, 50, 50]), mstype.float32)
+        >>> output, mask = dropout(x)
+    """
+
+    @prim_attr_register
+    def __init__(self, drop_prob=0):
+        self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
+
+    def infer_shape(self, x_shape):
+        validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
+        mask_shape = x_shape
+        return x_shape, mask_shape
+
+    def infer_dtype(self, x_dtype):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
+        return x_dtype, x_dtype
+
+
+class DropoutGrad(PrimitiveWithInfer):
+    """
+    Computes the gradient of Dropout: applies the mask produced by the
+    forward pass to the incoming gradient.
+
+    Args:
+        drop_prob (float): Probability of an element to be zeroed. Default: 0.
+
+    Inputs:
+        - **dy** (Tensor) - The gradient flowing back from Dropout's output.
+        - **mask** (Tensor) - The mask produced by the forward Dropout, with the same shape as `dy`.
+
+    Outputs:
+        Tensor, the gradient with respect to Dropout's input, with the same shape and dtype as `dy`.
+
+    Examples:
+        >>> dropout_grad = P.DropoutGrad(drop_prob=0.5)
+        >>> dx = dropout_grad(dy, mask)
+    """
+
+    @prim_attr_register
+    def __init__(self, drop_prob=0):
+        self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
+
+    def infer_shape(self, dy_shape, mask_shape):
+        return dy_shape
+
+    def infer_dtype(self, dy_dtype, mask_dtype):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
+        return dy_dtype
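
An end-to-end sketch of the two new primitives together (not part of the diff; shapes follow the docstring examples, and the mask is treated as opaque, only being threaded from Dropout into DropoutGrad):

    import numpy as np
    from mindspore import Tensor, context
    from mindspore import dtype as mstype
    from mindspore.ops import operations as P

    context.set_context(device_target="GPU")

    dropout = P.Dropout(drop_prob=0.5)
    x = Tensor(np.ones([20, 16, 50, 50]), mstype.float32)
    output, mask = dropout(x)    # forward: masked output plus the mask itself

    dropout_grad = P.DropoutGrad(drop_prob=0.5)
    dy = Tensor(np.ones([20, 16, 50, 50]), mstype.float32)
    dx = dropout_grad(dy, mask)  # backward: gradient masked the same way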

@@ -17,7 +17,9 @@ import numpy as np
 import pytest
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import context
+context.set_context(device_target="Ascend")
 
 
 def test_check_dropout_3():
     Tensor(np.ones([20, 16, 50]).astype(np.int32))

@@ -19,26 +19,26 @@ from mindspore.common.api import _executor
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import dtype as mstype
+from mindspore import context
+context.set_context(device_target="Ascend")
 
 
 def test_check_dropout_1():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.8)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)
 
 
 def test_check_dropout_2():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.3, seed0=1)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)
 
 
 def test_check_dropout_3():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.3, seed0=1, seed1=1)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)
 
 
 class Net_Dropout(nn.Cell):
