add dynamic shape support and testcases to GPU Dropout

pull/11560/head
TFBunny 4 years ago
parent c37aa71009
commit 2fc5ebd077

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,14 +28,7 @@ namespace kernel {
 template <typename T>
 class DropoutGpuFwdKernel : public GpuKernel {
  public:
-  DropoutGpuFwdKernel()
-      : cudnn_handle_(nullptr),
-        is_null_input_(false),
-        num_count_(0),
-        keep_prob_(0.0),
-        states_init_(false),
-        mask_generator_(nullptr) {}
+  DropoutGpuFwdKernel() { ResetResource(); }
   ~DropoutGpuFwdKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -96,6 +89,18 @@ class DropoutGpuFwdKernel : public GpuKernel {
     return true;
   }

+  void ResetResource() noexcept override {
+    cudnn_handle_ = nullptr;
+    is_null_input_ = false;
+    num_count_ = 0;
+    keep_prob_ = 0.0;
+    states_init_ = false;
+    mask_generator_ = nullptr;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
  protected:
   void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
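Why the constructor now delegates to ResetResource(): with dynamic shapes the same kernel object is reused across launches whose input shapes differ, so the framework resets it and re-runs Init instead of constructing a new kernel. Below is a minimal Python sketch of that lifecycle; the class and size arithmetic are illustrative, not MindSpore source.

# Sketch of the reset/re-init cycle ResetResource() supports (assumed
# lifecycle, illustrative names).
from functools import reduce

class DropoutKernelSketch:
    def __init__(self):
        self.reset_resource()  # constructor delegates to the reset path

    def reset_resource(self):
        """Mirror of ResetResource(): restore post-construction defaults."""
        self.num_count = 0
        self.keep_prob = 0.0
        self.states_init = False
        self.input_size_list = []
        self.output_size_list = []
        self.workspace_size_list = []

    def init(self, input_shape, keep_prob):
        """Recompute size lists from the (possibly new) input shape."""
        self.keep_prob = keep_prob
        self.num_count = reduce(lambda a, b: a * b, input_shape, 1)
        bytes_per_elem = 4  # float32
        self.input_size_list = [self.num_count * bytes_per_elem]
        # Dropout produces an output tensor and a mask of the same size.
        self.output_size_list = [self.num_count * bytes_per_elem] * 2

kernel = DropoutKernelSketch()
kernel.init([32, 16, 2, 5], keep_prob=0.4)     # first launch
kernel.reset_resource()                        # shape changed at runtime
kernel.init([32, 16, 2, 5, 6], keep_prob=0.4)  # re-init with the new shape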

@@ -85,6 +85,8 @@ AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list);

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -573,6 +573,23 @@ AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const Primitiv
   return std::make_shared<AbstractTuple>(args_list);
 }

+AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  ShapeVector shape = x->shape()->shape();
+  ShapeVector min_shape = x->shape()->min_shape();
+  ShapeVector max_shape = x->shape()->max_shape();
+  (void)CheckMinMaxShape(shape, &min_shape, &max_shape);
+  auto output_shape =
+    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  AbstractBasePtrList ret = {output_shape, output_shape};
+  return std::make_shared<AbstractTuple>(ret);
+}
+
 AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple and a tensor.
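What the new InferImplDropout contributes: Dropout is shape-preserving, so both outputs (the result and the mask) inherit the input's abstract shape, including the min_shape/max_shape bounds that describe any dynamic dims. A hedged Python sketch of that rule follows; the helper and its bound fallback are illustrative stand-ins, not MindSpore's CheckMinMaxShape.

# Illustrative restatement of the shape rule above (assumed helper).
def infer_dropout_shape(shape, min_shape=None, max_shape=None):
    # A dim of -1 is dynamic; without explicit bounds, fall back to the
    # static dims (assumption, for illustration only).
    min_shape = min_shape if min_shape else list(shape)
    max_shape = max_shape if max_shape else list(shape)
    out = {"shape": list(shape), "min_shape": min_shape, "max_shape": max_shape}
    return out, out  # output tensor and mask share identical shape info

# A tensor whose leading dim is only known to lie in [1, 32]:
print(infer_dropout_shape([-1, 16], min_shape=[1, 16], max_shape=[32, 16]))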

@@ -123,6 +123,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimBpropCut, {InferImplBpropCut, true}},
     {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
     {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
+    {prim::kPrimDropout, {InferImplDropout, true}},
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
     {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
     {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
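This map entry is what routes frontend shape inference for Dropout to the new C++ implementation once the Python primitive stops inferring shapes itself. A rough Python analogue of the lookup, with illustrative names (the boolean flag carried by the real map entries is treated as opaque and omitted here):

# Assumed analogue of PrimitiveEvalImplMap dispatch (illustrative only).
INFER_IMPL_MAP = {
    "LayerNorm": "InferImplLayerNorm",
    "LayerNormGrad": "InferImplLayerNormGrad",
    "Dropout": "InferImplDropout",          # the entry added by this commit
    "DropoutGenMask": "InferImplDropoutGenMask",
}

def lookup_infer_impl(prim_name):
    impl = INFER_IMPL_MAP.get(prim_name)
    if impl is None:
        raise KeyError(f"no frontend infer impl registered for {prim_name}")
    return impl

print(lookup_infer_impl("Dropout"))  # -> InferImplDropout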

@@ -6336,7 +6336,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
         return var_dtype, accum_dtype, linear_dtype


-class Dropout(PrimitiveWithInfer):
+class Dropout(PrimitiveWithCheck):
     """
     During training, randomly zeroes some of the elements of the input tensor with probability.
@@ -6367,15 +6367,12 @@ class Dropout(PrimitiveWithInfer):
         self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name)
         self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)

-    def infer_shape(self, x_shape):
+    def check_shape(self, x_shape):
         validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", self.name)
-        mask_shape = x_shape
-        return x_shape, mask_shape

-    def infer_dtype(self, x_dtype):
+    def check_dtype(self, x_dtype):
         valid_dtypes = (mstype.float16, mstype.float32)
         validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
-        return x_dtype, x_dtype


 class Dropout3d(PrimitiveWithInfer):
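The switch from PrimitiveWithInfer to PrimitiveWithCheck is the Python-side half of dynamic-shape support: check_shape and check_dtype only validate their inputs and return nothing, so output shapes come from the registered C++ InferImplDropout. A hedged before/after sketch in plain Python (not MindSpore source):

# Illustrative contrast between the two primitive base styles.
class DropoutInferStyle:            # before this commit
    def infer_shape(self, x_shape):
        mask_shape = x_shape
        return x_shape, mask_shape  # frontend pinned to static shapes

class DropoutCheckStyle:            # after this commit
    def check_shape(self, x_shape):
        # Validate rank only; produce no shapes, so dynamic (-1) dims
        # flow through to the C++ infer implementation.
        if len(x_shape) < 1:
            raise ValueError("x_shape must have rank >= 1")

DropoutCheckStyle().check_shape([-1, 16])  # dynamic dim passes validation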

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +17,9 @@ import pytest
 import mindspore.nn as nn
 from mindspore import Tensor
 import mindspore.context as context
 from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner


 class Net(nn.Cell):
     def __init__(self, keep_prob):
@@ -52,3 +53,47 @@ def test_dropout():
     mask_sum = np.sum(mask_np)
     assert np.count_nonzero(mask_np) == nonzero_count
     assert abs(mask_sum - nonzero_count)/nonzero_count < 0.1
+
+
+class DropoutDynamic(nn.Cell):
+    def __init__(self, keep_prob):
+        super(DropoutDynamic, self).__init__()
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.drop = P.Dropout(keep_prob)
+
+    def construct(self, x):
+        x = self.test_dynamic(x)
+        return self.drop(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dropout_dynamic():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x_1 = np.ones([32, 16, 2, 5]).astype(np.float32)
+    x_2 = np.ones([32, 16, 2, 5, 6]).astype(np.float32)
+    keep_prob = 0.4
+    net = DropoutDynamic(keep_prob)
+
+    output_1, mask_1 = net(Tensor(x_1))
+    elem_count_1 = x_1.size
+    nonzero_count_1 = np.count_nonzero(output_1.asnumpy())
+    assert (elem_count_1 * (keep_prob - 0.1)) < nonzero_count_1 < (elem_count_1 * (keep_prob + 0.1))
+    output_sum_1 = np.sum(output_1.asnumpy())
+    x_sum_1 = np.sum(x_1)
+    assert abs(output_sum_1 - x_sum_1)/x_sum_1 < 0.1
+    mask_sum_1 = np.sum(mask_1.asnumpy())
+    assert np.count_nonzero(mask_1.asnumpy()) == nonzero_count_1
+    assert abs(mask_sum_1 - nonzero_count_1)/nonzero_count_1 < 0.1
+
+    output_2, mask_2 = net(Tensor(x_2))
+    elem_count_2 = x_2.size
+    nonzero_count_2 = np.count_nonzero(output_2.asnumpy())
+    assert (elem_count_2 * (keep_prob - 0.1)) < nonzero_count_2 < (elem_count_2 * (keep_prob + 0.1))
+    output_sum_2 = np.sum(output_2.asnumpy())
+    x_sum_2 = np.sum(x_2)
+    assert abs(output_sum_2 - x_sum_2)/x_sum_2 < 0.1
+    mask_sum_2 = np.sum(mask_2.asnumpy())
+    assert np.count_nonzero(mask_2.asnumpy()) == nonzero_count_2
+    assert abs(mask_sum_2 - nonzero_count_2)/nonzero_count_2 < 0.1
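The tolerances in these assertions follow from dropout's definition: each element survives with probability keep_prob and survivors are scaled by 1/keep_prob, so the nonzero fraction concentrates near keep_prob and the output sum near the input sum. A quick standalone numpy check of both bounds (no GPU needed):

# Sanity check of the statistical bounds used by the tests above.
import numpy as np

x = np.ones(32 * 16 * 2 * 5, np.float32)
keep_prob = 0.4
mask = (np.random.rand(x.size) < keep_prob).astype(np.float32)
out = x * mask / keep_prob  # kept elements are rescaled by 1/keep_prob

nonzero = np.count_nonzero(out)
assert x.size * (keep_prob - 0.1) < nonzero < x.size * (keep_prob + 0.1)
assert abs(out.sum() - x.sum()) / x.sum() < 0.1
print("tolerance checks pass for", x.size, "elements")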
