diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py
index 8ad9957f8b..870e55816e 100644
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -31,3 +31,4 @@ from .dropout_grad import expand_dropoutgrad
 from .layernorm_grad import expand_layernormgrad
 from .logsoftmax import expand_logsoftmax
 from .logsoftmax_grad import expand_logsoftmaxgrad
+from .gkdropout import expand_gkdropout
diff --git a/mindspore/_extends/graph_kernel/expanders/gkdropout.py b/mindspore/_extends/graph_kernel/expanders/gkdropout.py
new file mode 100644
index 0000000000..f340087131
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/gkdropout.py
@@ -0,0 +1,49 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for GkDropout"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_gkdropout(expand_info):
+    """GkDropout expander"""
+    # get op info.
+    input_desc = expand_info['input_desc'][0]
+    mask_desc = expand_info['input_desc'][1]
+    keep_prob = None
+    for attr in expand_info['attr']:
+        if 'keep_prob' in attr:
+            keep_prob = attr['keep_prob']
+    if keep_prob is None:
+        raise RuntimeError("keep_prob does not exist in attrs.")
+    # generate a graph.
+    graph_builder = builder.GraphBuilder()
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
+        input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
+        graph_scope.set_input(input_x, input_mask)
+        keep_prob_v = graph_builder.value(input_x.dtype, keep_prob, "DefaultFormat")
+        r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob, "DefaultFormat")
+
+        mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v])  # mask = (uniform <= keep_prob)
+        mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
+
+        # compute result: y = x * mask / keep_prob
+        result = graph_builder.emit('Mul', [r_keep_prob, input_x])
+        result = graph_builder.emit('Mul', [result, mask])
+        # set graph output.
+        graph_scope.set_output(result, mask)
+    graph = graph_builder.get()[0]
+    return graph
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
index e674bef5ad..7c2cbbe26c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
@@ -29,5 +29,7 @@ MS_REG_GPU_KERNEL_ONE(UniformInt, RandomOpGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(UniformReal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
                       RandomOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(CudnnUniformReal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+                      RandomOpGpuKernel, float)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
index 1f05be71a4..47a9773715 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
@@ -25,13 +25,22 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh"
+#include "include/curand.h"

 namespace mindspore {
 namespace kernel {
-enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_INT, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 };
+enum RandomOptype {
+  RANDOM_OP_NORMAL = 0,
+  RANDOM_OP_UNIFORM_INT,
+  RANDOM_OP_UNIFORM_REAL,
+  RANDOM_OP_CUDNN_UNIFORM_REAL,
+  RANDOM_OP_INVALID_TYPE = 255
+};

-const std::map<std::string, RandomOptype> kRandomOpTypeMap = {
-  {"StandardNormal", RANDOM_OP_NORMAL}, {"UniformInt", RANDOM_OP_UNIFORM_INT}, {"UniformReal", RANDOM_OP_UNIFORM_REAL}};
+const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL},
+                                                              {"UniformInt", RANDOM_OP_UNIFORM_INT},
+                                                              {"UniformReal", RANDOM_OP_UNIFORM_REAL},
+                                                              {"CudnnUniformReal", RANDOM_OP_CUDNN_UNIFORM_REAL}};

 template <typename T>
 class RandomOpGpuKernel : public GpuKernel {
@@ -76,6 +85,23 @@ class RandomOpGpuKernel : public GpuKernel {
                   reinterpret_cast<cudaStream_t>(stream_ptr));
       break;
     }
+    case RANDOM_OP_CUDNN_UNIFORM_REAL: {
+      float *mask_f = GetDeviceAddress<float>(outputs, 0);
+      if (!states_init_) {
+        CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT),
+                                     "Failed to create generator");
+        CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, seed_),
+                                     "Failed to SetPseudoRandomGeneratorSeed");
+        MS_EXCEPTION_IF_NULL(mask_generator_);
+        states_init_ = true;
+      }
+      CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                   "Failed to set stream for generator");
+      // cuRAND generators only support float or double, so the mask is generated as float.
+      CHECK_CURAND_RET_WITH_EXCEPT(curandGenerateUniform(mask_generator_, mask_f, outputs[0]->size / sizeof(float)),
+                                   "Failed to generate uniform");
+      break;
+    }
     default: {
       MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported.";
     }
@@ -148,6 +174,8 @@ class RandomOpGpuKernel : public GpuKernel {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
+  curandGenerator_t mask_generator_;
+  bool states_init_{false};
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h
index d496bb42ab..23f9d3bd82 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.h
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.h
@@ -44,6 +44,7 @@ constexpr size_t kMulInputNum = 3;
 constexpr size_t kRsqrtInputNum = 2;
 constexpr size_t kSubInputNum = 3;
 constexpr size_t kAssignSubInputNum = 3;
+constexpr size_t kDropoutInputNum = 2;

 constexpr size_t kConvBn1OutputNum = 3;
 constexpr size_t kBn2ReluOutputNum = 4;
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index 716bf7d320..a899e1749e 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -25,6 +25,7 @@
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "backend/optimizer/graph_kernel/substitute_dropout.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "mindspore/core/ir/graph_utils.h"
 #include "pipeline/jit/parse/python_adapter.h"
@@ -242,6 +243,10 @@ void GraphKernelExpander::ToPrimitive(const FuncGraphPtr &func_graph) const {
 bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
   expand_ops_ = GetExpandOps();
   MS_EXCEPTION_IF_NULL(func_graph);
+  if (expand_ops_.count(prim::kPrimGkDropout) > 0) {
+    std::shared_ptr<Pass> pass = std::make_shared<SubstituteDropout>();
+    pass->Run(func_graph);
+  }
   auto mng = func_graph->manager();
   if (mng == nullptr) {
     mng = Manage(func_graph, true);
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 25f6dc002a..9d42b616e9 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -711,7 +711,8 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimTanhGrad,
     prim::kPrimReduceMean,
     prim::kPrimMaximumGrad,
-    prim::kPrimMinimumGrad
+    prim::kPrimMinimumGrad,
+    prim::kPrimGkDropout
 #endif
   };
   return expand_ops;
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
index b75f8efc2c..2fda42c07c 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
@@ -26,11 +26,15 @@
 #include <vector>
 #include "ir/anf.h"
 #include "ir/func_graph.h"
+#include "ir/primitive.h"
 #include "backend/session/kernel_graph.h"
 #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
 #include <nlohmann/json.hpp>

 namespace mindspore {
+namespace prim {
+inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
+}  // namespace prim
 namespace opt {
 using kernel::DumpOption;
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc
new file mode 100644
index 0000000000..9393601c0f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/substitute_dropout.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "base/core_ops.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/tensor.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+#include "runtime/device/kernel_info.h"
+
+namespace mindspore {
+namespace opt {
+unsigned int SubstituteDropout::seed_ = time(NULL);
+
+const BaseRef SubstituteDropout::DefinePattern() const {
+  VarPtr Xs = std::make_shared<Var>();
+  return VectorRef({prim::kPrimDropout, Xs});
+}
+
+void SetNewKernelInfo(const CNodePtr &kernel_node) {
+  std::vector<std::string> inputs_format;
+  std::vector<TypeId> inputs_type;
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index));
+  }
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> outputs_type;
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+    outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
+  }
+  std::string origin_data_format = kOpFormat_DEFAULT;
+  auto cnode_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  cnode_info_builder->SetOriginDataFormat(origin_data_format);
+  cnode_info_builder->SetInputsFormat(inputs_format);
+  cnode_info_builder->SetInputsDeviceType(inputs_type);
+  cnode_info_builder->SetOutputsFormat(outputs_format);
+  cnode_info_builder->SetOutputsDeviceType(outputs_type);
+  cnode_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE);
+  cnode_info_builder->SetProcessor(kernel::Processor::CUDA);
+  auto cnode_selected_info = cnode_info_builder->Build();
+  AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get());
+}
+
+const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                            const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  CNodePtr cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->inputs().size() < kDropoutInputNum) {
+    MS_LOG(EXCEPTION) << "The input number of Dropout is wrong.";
+  }
+  AbstractBasePtr old_abstract = cnode->abstract()->Clone();
+  auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
+  ShapeVector shape_i64;
+  std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); });
+
+  // create the shape tensor that feeds CudnnUniformReal
+  AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal)};
+  auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())),
+                                                 static_cast<void *>(&shape[0]), kNumberTypeInt64);
+  uniform_input.push_back(NewValueNode(tensor));
+  uniform_input[1]->set_abstract(tensor->ToAbstract());
+  uniform_input[1]->set_kernel_info(std::make_shared<device::KernelInfo>());
+  std::string origin_data_format = kOpFormat_DEFAULT;
+  std::vector<std::string> outputs_format = {origin_data_format};
+  std::vector<TypeId> outputs_type = {kNumberTypeInt32};
+  auto tensor_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  tensor_info_builder->SetOriginDataFormat(origin_data_format);
+  tensor_info_builder->SetOutputsFormat(outputs_format);
+  tensor_info_builder->SetOutputsDeviceType(outputs_type);
+  tensor_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE);
+  tensor_info_builder->SetProcessor(kernel::Processor::CUDA);
+  auto tensor_selected_info = tensor_info_builder->Build();
+  AnfAlgo::SetSelectKernelBuildInfo(tensor_selected_info, uniform_input[1].get());
+
+  // create the new uniform_real_node
+  auto uniform_real_node = func_graph->NewCNode(uniform_input);
+  AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(rand_r(&seed_))));
+  AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(rand_r(&seed_))));
+  auto uniform_abstract = std::make_shared<abstract::AbstractTensor>(std::make_shared<Float>(32), shape_i64);
+  uniform_real_node->set_abstract(uniform_abstract);
+  uniform_real_node->set_kernel_info(std::make_shared<device::KernelInfo>());
+  SetNewKernelInfo(uniform_real_node);
+
+  // create the new node with two inputs: the first is cnode->input(1), the second is uniform_real_node
+  AnfNodePtrList new_node_inputs = {NewValueNode(prim::kPrimGkDropout)};
+  new_node_inputs.push_back(cnode->input(1));
+  new_node_inputs.push_back(uniform_real_node);
+  auto new_node = func_graph->NewCNode(new_node_inputs);
+  AnfAlgo::GetCNodePrimitive(new_node)->set_attr("keep_prob", AnfAlgo::GetCNodePrimitive(cnode)->GetAttr("keep_prob"));
+  new_node->set_abstract(old_abstract);
+  new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
+  SetNewKernelInfo(new_node);
+  return new_node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h
new file mode 100644
index 0000000000..33804e77a4
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class SubstituteDropout : public PatternProcessPass {
+ public:
+  explicit SubstituteDropout(bool multigraph = true) : PatternProcessPass("substitute_dropout", multigraph) {}
+  ~SubstituteDropout() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  static unsigned int seed_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 630e1bb7a7..499e023d7e 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -164,6 +164,9 @@ inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop =
 inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
 inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
 inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad");
+inline const PrimitivePtr kPrimDropout = std::make_shared<Primitive>("Dropout");
+inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal");
+inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal");
 inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
 inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
 inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py
index 618c8e1ebd..9ab3ef39a8 100644
--- a/mindspore/ops/operations/random_ops.py
+++ b/mindspore/ops/operations/random_ops.py
@@ -118,7 +118,6 @@ class StandardLaplace(PrimitiveWithInfer):
         return out


-
 class Gamma(PrimitiveWithInfer):
     r"""
     Produces random positive floating-point values x, distributed according to probability density function:
@@ -532,6 +531,7 @@ class Multinomial(PrimitiveWithInfer):
                "value": None}
         return out

+
 class UniformCandidateSampler(PrimitiveWithInfer):
     r"""
     Uniform candidate sampler.
diff --git a/tests/st/ops/graph_kernel/test_dropout.py b/tests/st/ops/graph_kernel/test_dropout.py
new file mode 100644
index 0000000000..7e6576bb4c
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_dropout.py
@@ -0,0 +1,55 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self, keep_prob):
+        super(Net, self).__init__()
+        self.drop = P.Dropout(keep_prob)
+
+    def construct(self, x_):
+        return self.drop(x_)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dropout():
+    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
+    x_shape = [4096, 768]
+    x = np.ones(x_shape).astype(np.float32)
+    keep_prob = 0.9
+    dropout = Net(keep_prob)
+    tx = Tensor(x)
+    output, mask = dropout(tx)
+
+    output_np = output.asnumpy()
+    elem_count = x.size
+    nonzero_count = np.count_nonzero(output_np)
+    assert (elem_count * (keep_prob - 0.1)) < nonzero_count < (elem_count * (keep_prob + 0.1))
+    output_sum = np.sum(output_np)
+    x_sum = np.sum(x)
+    assert abs(output_sum - x_sum) / x_sum < 0.1
+    # check mask: mask is 0/1 and matches the nonzero positions of output
+    mask_np = mask.asnumpy()
+    mask_sum = np.sum(mask_np)
+    assert np.count_nonzero(mask_np) == nonzero_count
+    assert abs(mask_sum - nonzero_count) / nonzero_count < 0.1
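
Review note: the substitution implements standard inverted dropout. Below is a minimal NumPy sketch of the semantics the expander builds (the helper name gkdropout_reference is illustrative only, not part of this patch); it mirrors the expander graph above (LessEqual, Cast, two Mul), with the device-side CudnnUniformReal sample replaced by a host-side uniform draw:

import numpy as np

def gkdropout_reference(x, uniform_sample, keep_prob):
    """mask = cast(uniform <= keep_prob); output = x * (1 / keep_prob) * mask"""
    mask = (uniform_sample <= keep_prob).astype(x.dtype)
    output = x * (1.0 / keep_prob) * mask
    return output, mask

rng = np.random.default_rng(0)
x = np.ones((4096, 768), np.float32)
uniform = rng.random(x.shape, dtype=np.float32)
out, mask = gkdropout_reference(x, uniform, keep_prob=0.9)
# the 1 / keep_prob rescaling keeps the expected sum equal to the input sum,
# which is exactly what test_dropout asserts with a 10% tolerance
assert abs(out.sum() - x.sum()) / x.sum() < 0.1

Because about keep_prob of the uniform samples fall at or below keep_prob, roughly that fraction of elements survive, which is why the test can bound both the nonzero count and the sum.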