substitute Dropout with CudnnUniformReal and GkDropout

pull/8994/head
zengzitao 4 years ago
parent 08dc1481c7
commit 3ef0e9f053

@@ -31,3 +31,4 @@ from .dropout_grad import expand_dropoutgrad
from .layernorm_grad import expand_layernormgrad
from .logsoftmax import expand_logsoftmax
from .logsoftmax_grad import expand_logsoftmaxgrad
from .gkdropout import expand_gkdropout

@@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for GkDropOut"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_gkdropout(expand_info):
    """GkDropout expander"""
    # get op info.
    input_desc = expand_info['input_desc'][0]
    mask_desc = expand_info['input_desc'][1]
    keep_prob = None
    for attr in expand_info['attr']:
        if 'keep_prob' in attr:
            keep_prob = attr['keep_prob']
    if keep_prob is None:
        raise RuntimeError("keep_prob does not exist in attrs.")
    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
        input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
        graph_scope.set_input(input_x, input_mask)
        keep_prob_v = graph_builder.value(input_x.dtype, keep_prob, "DefaultFormat")
        r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob, "DefaultFormat")
        mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v])
        mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
        # compute result.
        result = graph_builder.emit('Mul', [r_keep_prob, input_x])
        result = graph_builder.emit('Mul', [result, mask])
        # set graph output.
        graph_scope.set_output(result, mask)
    graph = graph_builder.get()[0]
    return graph
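For intuition, the expanded GkDropout graph above is numerically equivalent to this NumPy sketch (the helper name is hypothetical, not part of the commit):

import numpy as np

def gk_dropout_reference(x, uniform_mask, keep_prob):
    # LessEqual + Cast: keep an element where its uniform draw is <= keep_prob.
    mask = (uniform_mask <= keep_prob).astype(x.dtype)
    # Inverted dropout: rescale survivors by 1 / keep_prob so the output's
    # expected sum matches the input's.
    return (1.0 / keep_prob) * x * mask, mask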

@@ -29,5 +29,7 @@ MS_REG_GPU_KERNEL_ONE(UniformInt,
RandomOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(UniformReal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
RandomOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(CudnnUniformReal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
RandomOpGpuKernel, float)
} // namespace kernel
} // namespace mindspore

@@ -25,13 +25,22 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh"
#include "include/curand.h"
namespace mindspore {
namespace kernel {
enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_INT, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 };
enum RandomOptype {
RANDOM_OP_NORMAL = 0,
RANDOM_OP_UNIFORM_INT,
RANDOM_OP_UNIFORM_REAL,
RANDOM_OP_CUDNN_UNIFORM_REAL,
RANDOM_OP_INVALID_TYPE = 255
};
const std::map<std::string, RandomOptype> kRandomOpTypeMap = {
{"StandardNormal", RANDOM_OP_NORMAL}, {"UniformInt", RANDOM_OP_UNIFORM_INT}, {"UniformReal", RANDOM_OP_UNIFORM_REAL}};
const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL},
{"UniformInt", RANDOM_OP_UNIFORM_INT},
{"UniformReal", RANDOM_OP_UNIFORM_REAL},
{"CudnnUniformReal", RANDOM_OP_CUDNN_UNIFORM_REAL}};
template <typename T>
class RandomOpGpuKernel : public GpuKernel {
@@ -76,6 +85,23 @@ class RandomOpGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case RANDOM_OP_CUDNN_UNIFORM_REAL: {
float *mask_f = GetDeviceAddress<float>(outputs, 0);
if (!states_init_) {
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT),
"Failed to create generator");
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, seed_),
"Failed to SetPseudoRandomGeneratorSeed");
MS_EXCEPTION_IF_NULL(mask_generator_);
states_init_ = true;
}
CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to set stream for generator");
// cuRAND generators only support float or double, so the mask is generated as float.
CHECK_CURAND_RET_WITH_EXCEPT(curandGenerateUniform(mask_generator_, mask_f, outputs[0]->size / sizeof(float)),
"Failed to generate uniform");
break;
}
default: {
MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported.";
}
@@ -148,6 +174,8 @@ class RandomOpGpuKernel : public GpuKernel {
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
curandGenerator_t mask_generator_;
bool states_init_{false};
};
} // namespace kernel
} // namespace mindspore
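The new RANDOM_OP_CUDNN_UNIFORM_REAL branch lazily creates a cuRAND generator, seeds it once, binds it to the kernel's CUDA stream, and fills the output via curandGenerateUniform, which draws float32 values from (0.0, 1.0]. As a rough off-device stand-in (hypothetical helper; NumPy draws from [0, 1), a boundary difference that is immaterial for dropout masks):

import numpy as np

def cudnn_uniform_real_reference(shape, seed=None):
    # Approximates curandGenerateUniform: uniform float32 values, one per element.
    rng = np.random.default_rng(seed)
    return rng.random(shape, dtype=np.float32)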

@@ -44,6 +44,7 @@ constexpr size_t kMulInputNum = 3;
constexpr size_t kRsqrtInputNum = 2;
constexpr size_t kSubInputNum = 3;
constexpr size_t kAssignSubInputNum = 3;
constexpr size_t kDropoutInputNum = 2;
constexpr size_t kConvBn1OutputNum = 3;
constexpr size_t kBn2ReluOutputNum = 4;

@@ -25,6 +25,7 @@
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/substitute_dropout.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "mindspore/core/ir/graph_utils.h"
#include "pipeline/jit/parse/python_adapter.h"
@@ -242,6 +243,10 @@ void GraphKernelExpander::ToPrimitive(const FuncGraphPtr &func_graph) const {
bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
expand_ops_ = GetExpandOps();
MS_EXCEPTION_IF_NULL(func_graph);
if (expand_ops_.count(prim::kPrimGkDropout) > 0) {
std::shared_ptr<Pass> pass = std::make_shared<opt::SubstituteDropout>();
pass->Run(func_graph);
}
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);

@@ -711,7 +711,8 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
prim::kPrimTanhGrad,
prim::kPrimReduceMean,
prim::kPrimMaximumGrad,
prim::kPrimMinimumGrad
prim::kPrimMinimumGrad,
prim::kPrimGkDropout
#endif
};
return expand_ops;

@@ -26,11 +26,15 @@
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include <nlohmann/json.hpp>
namespace mindspore {
namespace prim {
inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
} // namespace prim
namespace opt {
using kernel::DumpOption;

@@ -0,0 +1,120 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/substitute_dropout.h"
#include <vector>
#include <string>
#include <algorithm>
#include <memory>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/tensor.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "runtime/device/kernel_info.h"
namespace mindspore {
namespace opt {
unsigned int SubstituteDropout::seed_ = time(nullptr);
const BaseRef SubstituteDropout::DefinePattern() const {
VarPtr Xs = std::make_shared<Var>();
return VectorRef({prim::kPrimDropout, Xs});
}
void SetNewKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
std::string origin_data_format = kOpFormat_DEFAULT;
auto cnode_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
cnode_info_builder->SetOriginDataFormat(origin_data_format);
cnode_info_builder->SetInputsFormat(inputs_format);
cnode_info_builder->SetInputsDeviceType(inputs_type);
cnode_info_builder->SetOutputsFormat(outputs_format);
cnode_info_builder->SetOutputsDeviceType(outputs_type);
cnode_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE);
cnode_info_builder->SetProcessor(kernel::Processor::CUDA);
auto cnode_selected_info = cnode_info_builder->Build();
AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get());
}
const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < kDropoutInputNum) {
MS_LOG(EXCEPTION) << "Dropout's input num is wrong";
}
AbstractBasePtr old_abstract = cnode->abstract()->Clone();
auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
ShapeVector shape_i64;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); });
// Create new tensor
AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal)};
auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())),
static_cast<void *>(&shape[0]), kNumberTypeInt64);
uniform_input.push_back(NewValueNode(tensor));
uniform_input[1]->set_abstract(tensor->ToAbstract());
uniform_input[1]->set_kernel_info(std::make_shared<device::KernelInfo>());
std::string origin_data_format = kOpFormat_DEFAULT;
std::vector<std::string> outputs_format = {origin_data_format};
std::vector<TypeId> outputs_type = {kNumberTypeInt32};
auto tensor_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
tensor_info_builder->SetOriginDataFormat(origin_data_format);
tensor_info_builder->SetOutputsFormat(outputs_format);
tensor_info_builder->SetOutputsDeviceType(outputs_type);
tensor_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE);
tensor_info_builder->SetProcessor(kernel::Processor::CUDA);
auto tensor_selected_info = tensor_info_builder->Build();
AnfAlgo::SetSelectKernelBuildInfo(tensor_selected_info, uniform_input[1].get());
// create new uniform_real_node
auto uniform_real_node = func_graph->NewCNode(uniform_input);
AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(rand_r(&seed_))));
AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(rand_r(&seed_))));
auto uniform_abstract = std::make_shared<abstract::AbstractTensor>(std::make_shared<Float>(32), shape_i64);
uniform_real_node->set_abstract(uniform_abstract);
uniform_real_node->set_kernel_info(std::make_shared<device::KernelInfo>());
SetNewKernelInfo(uniform_real_node);
// create new_node with two inputs: the first is cnode->input(1), the second is uniform_real_node
AnfNodePtrList new_node_inputs = {NewValueNode(prim::kPrimGkDropout)};
new_node_inputs.push_back(cnode->input(1));
new_node_inputs.push_back(uniform_real_node);
auto new_node = func_graph->NewCNode(new_node_inputs);
AnfAlgo::GetCNodePrimitive(new_node)->set_attr("keep_prob", AnfAlgo::GetCNodePrimitive(cnode)->GetAttr("keep_prob"));
new_node->set_abstract(old_abstract);
new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
SetNewKernelInfo(new_node);
return new_node;
}
} // namespace opt
} // namespace mindspore
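End to end, the pass rewrites out, mask = Dropout(x) into out, mask = GkDropout(x, CudnnUniformReal(shape(x))), with keep_prob copied over and fresh seed/seed2 attributes. A self-contained sketch of the resulting computation (hypothetical helper, assuming float32 input):

import numpy as np

def substituted_dropout(x, keep_prob, seed=None):
    # CudnnUniformReal stand-in: uniform floats with x's shape.
    u = np.random.default_rng(seed).random(x.shape, dtype=np.float32)
    # GkDropout expansion: threshold, cast, rescale by 1 / keep_prob.
    mask = (u <= keep_prob).astype(x.dtype)
    return (1.0 / keep_prob) * x * mask, mask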

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class SubstituteDropout : public PatternProcessPass {
public:
explicit SubstituteDropout(bool multigraph = true) : PatternProcessPass("substitute_dropout", multigraph) {}
~SubstituteDropout() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
static unsigned int seed_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_

@@ -164,6 +164,9 @@ inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Pri
inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad");
inline const PrimitivePtr kPrimDropout = std::make_shared<Primitive>("Dropout");
inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal");
inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal");
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");

@@ -118,7 +118,6 @@ class StandardLaplace(PrimitiveWithInfer):
return out
class Gamma(PrimitiveWithInfer):
r"""
Produces random positive floating-point values x, distributed according to probability density function:
@@ -532,6 +531,7 @@ class Multinomial(PrimitiveWithInfer):
"value": None}
return out
class UniformCandidateSampler(PrimitiveWithInfer):
r"""
Uniform candidate sampler.

@@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net(nn.Cell):
    def __init__(self, keep_prob):
        super(Net, self).__init__()
        self.drop = P.Dropout(keep_prob)

    def construct(self, x_):
        return self.drop(x_)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
    x_shape = [4096, 768]
    x = np.ones(x_shape).astype(np.float32)
    keep_prob = 0.9
    dropout = Net(keep_prob)
    tx = Tensor(x)
    output, mask = dropout(tx)
    output_np = output.asnumpy()
    elem_count = x.size
    nonzero_count = np.count_nonzero(output_np)
    assert (elem_count * (keep_prob - 0.1)) < nonzero_count < (elem_count * (keep_prob + 0.1))
    output_sum = np.sum(output_np)
    x_sum = np.sum(x)
    assert abs(output_sum - x_sum) / x_sum < 0.1
    # check mask
    mask_np = mask.asnumpy()
    mask_sum = np.sum(mask_np)
    assert np.count_nonzero(mask_np) == nonzero_count
    assert abs(mask_sum - nonzero_count) / nonzero_count < 0.1
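The ±0.1 window around keep_prob is deliberately loose: for a 4096 x 768 tensor the spread of the empirical keep fraction is tiny, as this back-of-the-envelope check shows (plain binomial arithmetic, not part of the test):

import numpy as np

n = 4096 * 768  # elements in the test tensor
p = 0.9         # keep_prob
# Std of the empirical keep fraction under independent Bernoulli(p) draws:
print(np.sqrt(p * (1 - p) / n))  # ~1.7e-4, so the +/-0.1 bound never flakes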