parent
08dc1481c7
commit
3ef0e9f053
@ -0,0 +1,49 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for GkDropOut"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
|
||||
def expand_gkdropout(expand_info):
    """Expand a GkDropOut op into a graph built from basic ops.

    Args:
        expand_info (dict): description of the op. This function reads
            ``expand_info['input_desc'][0]`` (the data input) and
            ``expand_info['input_desc'][1]`` (the mask input), each providing
            'shape', 'data_type' and 'format', and ``expand_info['attr']``,
            an iterable of mappings one of which must hold 'keep_prob'.

    Returns:
        The first graph returned by the graph builder. Its outputs are the
        scaled-and-masked result followed by the cast mask.

    Raises:
        RuntimeError: if no attr entry supplies a non-None 'keep_prob'.
    """
    # Unpack the op description.
    x_desc = expand_info['input_desc'][0]
    mask_desc = expand_info['input_desc'][1]

    # If several attr entries carry 'keep_prob', the last one wins.
    found = [attr['keep_prob'] for attr in expand_info['attr'] if 'keep_prob' in attr]
    keep_prob = found[-1] if found else None
    if keep_prob is None:
        raise RuntimeError("keep_prob does not exist in attrs.")

    # Build the expanded graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # Graph inputs: the data tensor and the mask tensor.
        data = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
        rand_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
        graph_scope.set_input(data, rand_mask)

        # Scalar constants, both in the data tensor's dtype.
        threshold = graph_builder.value(data.dtype, keep_prob, "DefaultFormat")
        scale = graph_builder.value(data.dtype, 1.0 / keep_prob, "DefaultFormat")

        # Elements whose mask input is <= keep_prob are kept; the comparison
        # result is cast to the data dtype so it can be used as a multiplier.
        keep = graph_builder.emit('LessEqual', [rand_mask, threshold])
        keep = graph_builder.emit('Cast', [keep], attrs={'dst_type': data.dtype})

        # result = (1 / keep_prob) * x * keep
        scaled = graph_builder.emit('Mul', [scale, data])
        dropped = graph_builder.emit('Mul', [scaled, keep])

        # Graph outputs: the dropout result and the cast mask.
        graph_scope.set_output(dropped, keep)

    return graph_builder.get()[0]
|
@ -0,0 +1,120 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/substitute_dropout.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// State word for rand_r(), shared by every SubstituteDropout instance.
// It is initialised once from the wall clock; the narrowing from time_t to
// unsigned int is intentional and therefore spelled out explicitly.
unsigned int SubstituteDropout::seed_ = static_cast<unsigned int>(time(nullptr));
|
||||
|
||||
// Pattern handled by this pass: the Dropout primitive followed by one
// pattern variable standing for its input.
const BaseRef SubstituteDropout::DefinePattern() const {
  VarPtr dropout_input = std::make_shared<Var>();
  return VectorRef({prim::kPrimDropout, dropout_input});
}
|
||||
|
||||
void SetNewKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index));
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
std::string origin_data_format = kOpFormat_DEFAULT;
|
||||
auto cnode_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
cnode_info_builder->SetOriginDataFormat(origin_data_format);
|
||||
cnode_info_builder->SetInputsFormat(inputs_format);
|
||||
cnode_info_builder->SetInputsDeviceType(inputs_type);
|
||||
cnode_info_builder->SetOutputsFormat(outputs_format);
|
||||
cnode_info_builder->SetOutputsDeviceType(outputs_type);
|
||||
cnode_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
cnode_info_builder->SetProcessor(kernel::Processor::CUDA);
|
||||
auto cnode_selected_info = cnode_info_builder->Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get());
|
||||
}
|
||||
|
||||
// Replaces a matched Dropout node with a CudnnUniformReal node feeding a
// GkDropout node.
//
// The new graph fragment is:
//   shape_tensor (value node) -> CudnnUniformReal -> GkDropout(x, uniform)
// where `x` is the original Dropout's first input. GkDropout inherits the
// original node's abstract and its "keep_prob" attribute.
//
// @param func_graph graph that owns `node`; new nodes are created in it.
// @param node       the matched Dropout node.
// @return the new GkDropout node that replaces `node`.
// Raises a MindSpore exception if `func_graph`, `node` or the node's abstract
// is null, if `node` is not a CNode, or if it has too few inputs.
const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                            const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().size() < kDropoutInputNum) {
    MS_LOG(EXCEPTION) << "Dropout's input num is wrong";
  }
  MS_EXCEPTION_IF_NULL(cnode->abstract());
  AbstractBasePtr old_abstract = cnode->abstract()->Clone();

  // Device shape of the data input, converted to the signed 64-bit form used
  // by ShapeVector.
  auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
  ShapeVector shape_i64;
  shape_i64.reserve(shape.size());
  std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); });

  // Create a 1-D Int64 tensor holding the shape. It is the only input of the
  // uniform-random node. The buffer comes from the already converted
  // `shape_i64`: data() is well defined even for an empty shape, whereas
  // `&shape[0]` is not, and the element type now really is int64.
  AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal)};
  auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape_i64.size())),
                                                 static_cast<void *>(shape_i64.data()), kNumberTypeInt64);
  uniform_input.push_back(NewValueNode(tensor));
  uniform_input[1]->set_abstract(tensor->ToAbstract());
  uniform_input[1]->set_kernel_info(std::make_shared<device::KernelInfo>());

  // Kernel build info for the shape value node.
  // NOTE(review): the tensor is created as Int64 but the advertised output
  // device type is Int32; confirm this is what the consumer kernel expects.
  std::string origin_data_format = kOpFormat_DEFAULT;
  std::vector<std::string> outputs_format = {origin_data_format};
  std::vector<TypeId> outputs_type = {kNumberTypeInt32};
  auto tensor_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  tensor_info_builder->SetOriginDataFormat(origin_data_format);
  tensor_info_builder->SetOutputsFormat(outputs_format);
  tensor_info_builder->SetOutputsDeviceType(outputs_type);
  tensor_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE);
  tensor_info_builder->SetProcessor(kernel::Processor::CUDA);
  auto tensor_selected_info = tensor_info_builder->Build();
  AnfAlgo::SetSelectKernelBuildInfo(tensor_selected_info, uniform_input[1].get());

  // Create the uniform-random node. Each substitution draws two fresh seeds
  // from the class-wide rand_r state, and the output is float32 with the same
  // shape as the data input.
  auto uniform_real_node = func_graph->NewCNode(uniform_input);
  AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(rand_r(&seed_))));
  AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(rand_r(&seed_))));
  auto uniform_abstract = std::make_shared<abstract::AbstractTensor>(std::make_shared<Float>(32), shape_i64);
  uniform_real_node->set_abstract(uniform_abstract);
  uniform_real_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  SetNewKernelInfo(uniform_real_node);

  // Create the GkDropout node with two inputs: cnode->input(1) and
  // uniform_real_node.
  AnfNodePtrList new_node_inputs = {NewValueNode(prim::kPrimGkDropout)};
  new_node_inputs.push_back(cnode->input(1));
  new_node_inputs.push_back(uniform_real_node);
  auto new_node = func_graph->NewCNode(new_node_inputs);
  AnfAlgo::GetCNodePrimitive(new_node)->set_attr("keep_prob", AnfAlgo::GetCNodePrimitive(cnode)->GetAttr("keep_prob"));
  new_node->set_abstract(old_abstract);
  new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  SetNewKernelInfo(new_node);
  return new_node;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SubstituteDropout : public PatternProcessPass {
|
||||
public:
|
||||
explicit SubstituteDropout(bool multigraph = true) : PatternProcessPass("substitute_dropout", multigraph) {}
|
||||
~SubstituteDropout() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
static unsigned int seed_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_
|
@ -0,0 +1,55 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Minimal cell that wraps a single ``P.Dropout`` op."""

    def __init__(self, keep_prob):
        """Create the dropout op with the given ``keep_prob``."""
        super().__init__()
        self.drop = P.Dropout(keep_prob)

    def construct(self, x_):
        """Apply the dropout op to ``x_`` and return whatever the op returns."""
        return self.drop(x_)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout():
    """Statistical sanity check of Dropout on GPU with graph kernel enabled.

    Feeds an all-ones float32 tensor through ``Net`` and checks that:
      * the fraction of non-zero outputs is within 0.1 of ``keep_prob``;
      * the output sum is within 10% of the input sum;
      * the mask has the same number of non-zero entries as the output, and
        its sum is within 10% of that count.
    """
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")

    keep_prob = 0.9
    ones = np.ones([4096, 768]).astype(np.float32)
    net = Net(keep_prob)
    output, mask = net(Tensor(ones))

    # Check the output: how many elements survived, and the overall sum.
    out_arr = output.asnumpy()
    total = ones.size
    kept = np.count_nonzero(out_arr)
    lower_bound = total * (keep_prob - 0.1)
    upper_bound = total * (keep_prob + 0.1)
    assert lower_bound < kept < upper_bound

    out_sum = np.sum(out_arr)
    in_sum = np.sum(ones)
    assert abs(out_sum - in_sum)/in_sum < 0.1

    # Check the mask against the output.
    mask_arr = mask.asnumpy()
    mask_total = np.sum(mask_arr)
    assert np.count_nonzero(mask_arr) == kept
    assert abs(mask_total - kept)/kept < 0.1
|
Loading…
Reference in new issue