support atomic clean and change package for akg.

pull/9089/head
tronzhang 4 years ago
parent 125940314f
commit 2190da9946

2
akg

@ -1 +1 @@
Subproject commit 6ffe9c24319d7297d0feeb10ee2bd8135e24c5c8
Subproject commit 0a0338fecd54c654c1992af156d41e036569343c

@ -37,7 +37,9 @@ def expand_gkdropout(expand_info):
keep_prob_v = graph_builder.value(input_x.dtype, keep_prob, "DefaultFormat")
r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob, "DefaultFormat")
mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v])
if input_mask.dtype != input_x.dtype:
input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) # output is bool type
mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
# compute result

@ -16,7 +16,7 @@
from .model import PrimLib, Graph, Tensor
use_poly_reduce = False
use_poly_reduce = True
class GraphSplitByPattern:
"""Graph splitter"""

@ -204,6 +204,16 @@ class CompositeGraph:
def load(self, desc):
"""Load Graph from json"""
def _attr_of(op, inputs, output):
def _get_axis_while_none(input_shape, output_shape):
red_axis = []
if len(output_shape) == len(input_shape):
for s, i in enumerate(output_shape):
if s == 1 and input_shape[i] > 1:
red_axis.append(i)
else:
red_axis = list(range(len(output_shape)))
return red_axis
attr = {}
if op['name'] not in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
return attr
@ -211,10 +221,7 @@ class CompositeGraph:
if a['name'] == 'axis':
red_axis, dim_size = [], len(inputs[0].shape)
if not a['value']:
assert len(output.shape) == len(inputs[0].shape)
for i in range(len(output.shape)):
if output.shape[i] == 1 and inputs[0].shape[i] > 1:
red_axis.append(i)
red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
else:
if isinstance(a['value'], int):
a['value'] = [a['value']]

@ -244,7 +244,11 @@ bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, co
output_json[kJsonKeyFormat] = this->GetOutputFormat(anf_node, i);
output_json[kJsonKeyName] = output_name;
output_json[kJsonKeyTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
output_json[kJsonKeyShape] = this->GetOutputShape(anf_node, i);
auto output_shape = this->GetOutputShape(anf_node, i);
if (output_shape.empty()) {
output_shape.push_back(1);
}
output_json[kJsonKeyShape] = output_shape;
outputs_json->push_back(output_json);
}
return true;
@ -680,7 +684,11 @@ nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNod
GetTensorName(node_json_map.at(tmp_input.first), kJsonKeyInputDesc, tmp_input.second);
input_desc_json[kJsonKeyDataType] = dtype;
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(tmp_input.first, tmp_input.second.first);
input_desc_json[kJsonKeyShape] = this->GetInputShape(tmp_input.first, tmp_input.second.first);
auto input_shape = this->GetInputShape(tmp_input.first, tmp_input.second.first);
if (input_shape.empty()) {
input_shape.push_back(1);
}
input_desc_json[kJsonKeyShape] = input_shape;
inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
}
return inputs_json;

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_GPU_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_GPU_H_
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace opt {
class AtomicCleanInsertter : public Pass {
public:
AtomicCleanInsertter() : Pass("atomic_clean") {}
~AtomicCleanInsertter() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
const FuncGraphManagerPtr &mng);
bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
const AnfNodePtr &user_node, int index);
void AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &pre_node, const AnfNodePtr &post_node,
const FuncGraphManagerPtr &mng);
void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
void CorrectAbstract(const AnfNodePtr &composite_node);
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng);
CNodePtr atomic_add_node_{nullptr};
size_t reduce_real_output_index_{0};
size_t real_output_num_{0};
};
using AtomicCleanInsertterPtr = std::shared_ptr<AtomicCleanInsertter>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_GPU_H_

@ -30,7 +30,9 @@ bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node) {
auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
if (main_primitive != nullptr && node_primitive != nullptr) {
if (main_primitive->name() != node_primitive->name()) {
// Some ops such as Reshape is not real op, cse these type will not get gain. And for ops fusion, keep these op
// alone can prevent some redundant output case (input -> reshape -> output).
if (main_primitive->name() != node_primitive->name() || IsPrimitiveCNode(node, prim::kPrimReshape)) {
return false;
}

@ -908,5 +908,126 @@ void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNo
depend_prior->insert(item);
}
}
std::string GetFormat(const AnfNodePtr &node) {
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto kernel_build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(kernel_build_info);
return kernel_build_info->GetOutputFormat(0);
}
TypePtr GetType(const AnfNodePtr &node) {
const auto &abstract = node->abstract();
auto type = abstract->BuildType();
MS_EXCEPTION_IF_NULL(type);
return type;
}
ShapeVector GetShape(const AnfNodePtr &node) {
auto abstract = node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
auto shape = abstract->GetShapeTrack();
if (shape == nullptr || !shape->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "Cannot get shape from " << node->fullname_with_scope();
}
return shape->cast<abstract::ShapePtr>()->shape();
}
std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(prim);
const auto &attrs = prim->attrs();
auto iter = attrs.find("axis");
if (iter == attrs.end()) {
MS_LOG(EXCEPTION) << "Origin node have no attributes!";
}
std::vector<int64_t> axis;
auto &v = iter->second;
if (v->isa<ValueList>() || v->isa<ValueTuple>()) {
auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value();
for (auto value : vec) {
if (value->isa<Int64Imm>()) {
axis.push_back(GetValue<int64_t>(value));
} else {
MS_LOG(EXCEPTION) << "Reduce axis type should be int64!";
}
}
} else if (v->isa<Int64Imm>()) {
axis.push_back(GetValue<int64_t>(v));
} else {
MS_LOG(EXCEPTION) << "Reduce axis should be a list or tuple!";
}
return axis;
}
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info) {
// Limitation: 1. Node's attributes should be set out of this function; 2. only one output.
MS_EXCEPTION_IF_NULL(out_info.type);
auto out_type = out_info.type;
if (auto otype = out_info.type->cast<TensorTypePtr>(); otype != nullptr) {
out_type = otype->element();
}
// Create CNode.
auto cnode = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode);
// Setup abstract.
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(out_type, out_info.shape);
cnode->set_abstract(abs_tensor);
// Setup kernel info.
auto kernel_info = std::make_shared<device::KernelInfo>();
cnode->set_kernel_info(kernel_info);
std::vector<size_t> feature_map_input_indexs;
kernel_info->set_feature_map_flag(false);
for (size_t i = 1; i < inputs.size(); ++i) {
if (AnfAlgo::IsFeatureMapOutput(inputs[i])) {
kernel_info->set_feature_map_flag(true);
feature_map_input_indexs.push_back(i);
}
}
if (inputs.size() == 1) {
kernel_info->set_feature_map_flag(true);
}
if (AnfAlgo::IsRealKernel(cnode)) {
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
}
// Setup kernel build info.
std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
for (size_t i = 1; i < inputs.size(); ++i) {
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
input_formats.push_back(input_format);
auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
input_types.push_back(input_type);
}
std::vector<std::string> output_formats = {out_info.format};
std::vector<TypeId> output_types = {out_type->type_id()};
kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
info_builder.SetInputsFormat(input_formats);
info_builder.SetInputsDeviceType(input_types);
info_builder.SetOutputsFormat(output_formats);
info_builder.SetOutputsDeviceType(output_types);
info_builder.SetProcessor(kernel::Processor::CUDA);
info_builder.SetKernelType(KernelType::AKG_KERNEL);
info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto selected_info = info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(selected_info, cnode.get());
func_graph->AddNode(cnode);
return cnode;
}
} // namespace opt
} // namespace mindspore

@ -27,6 +27,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include <nlohmann/json.hpp>
@ -38,6 +39,8 @@ inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropou
namespace opt {
using kernel::DumpOption;
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
constexpr auto kGraphKernelModule = "mindspore._extends.graph_kernel";
constexpr auto kGraphKernelSplitFunc = "split_with_json";
constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
@ -45,6 +48,12 @@ constexpr auto kJsonKeyMultiGraph = "multi_graph";
constexpr auto kJsonKeyGraphDesc = "graph_desc";
constexpr auto kJsonKeyGraphMode = "graph_mode";
struct DataInfo {
std::string format{kOpFormat_DEFAULT};
ShapeVector shape{1};
TypePtr type{nullptr};
};
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes,
AnfNodePtrList *src_outputs = nullptr);
@ -74,6 +83,49 @@ void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, Anf
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend);
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
std::string GetFormat(const AnfNodePtr &node);
TypePtr GetType(const AnfNodePtr &node);
ShapeVector GetShape(const AnfNodePtr &node);
std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
template <typename T>
ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
// Create tensor value.
if (info.shape.size() != 1 && info.shape[0] != 1) {
MS_LOG(EXCEPTION) << "Only support create scalar tensor value node!!!";
}
if (info.type == nullptr) {
MS_LOG(EXCEPTION) << "Data type is needed!!!";
}
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(info.type->type_id(), info.shape);
MS_EXCEPTION_IF_NULL(tensor);
tensor::DeviceInfo device_info{info.format, info.type};
tensor->set_device_info(device_info);
auto data_ptr = tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), &value, data_length);
if (ret_code != 0) {
MS_LOG(EXCEPTION) << "Failed to copy data into scalar tensor.";
}
// Create value node.
ValueNodePtr new_value_node = std::make_shared<ValueNode>(tensor);
new_value_node->set_abstract(tensor->ToAbstract());
auto kernel_info = std::make_shared<device::KernelInfo>();
new_value_node->set_kernel_info(kernel_info);
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{info.format});
std::vector<TypeId> types = {info.type->type_id()};
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
return new_value_node;
}
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_

@ -35,6 +35,7 @@
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
#include "backend/optimizer/gpu/reduce_precision_fusion.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
@ -176,6 +177,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
// After Simplify and Splitter, a lot of redundant getitem/maketuple
// will be exposed, use GetitemTuple Pass to delete them.
pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>());
pm->AddPass(std::make_shared<opt::BindValueToGraph>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);

@ -0,0 +1,124 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
class SumOutNet(Cell):
def __init__(self):
super(SumOutNet, self).__init__()
self.square = P.Square()
self.sum = P.ReduceSum()
def construct(self, x):
mul_res = self.square(x)
return self.sum(mul_res, (0,))
class SingleOutNet(Cell):
def __init__(self):
super(SingleOutNet, self).__init__()
self.add = P.TensorAdd()
self.mul = P.Mul()
self.sum = P.ReduceSum()
def construct(self, x, y):
mul_res = self.mul(x, y)
sum_res = self.sum(mul_res, ())
return self.add(sum_res, x)
class MultiOutNet(Cell):
def __init__(self):
super(MultiOutNet, self).__init__()
self.add = P.TensorAdd()
self.mul = P.Mul()
self.sum = P.ReduceSum()
def construct(self, x, y):
add_res = self.add(x, y)
mul_res = self.mul(add_res, add_res)
sum_res = self.sum(mul_res, ())
return self.add(add_res, sum_res)
def atomic_add_sum_output():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
expect = np.sum(np.square(input_x), axis=(0,))
net = SumOutNet()
result = net(Tensor(input_x))
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
def atomic_add_single_output():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 2, 2, 256]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 2, 2, 256]).astype(np.float32)
expect = np.sum(input_x * input_y) + input_x
net = SingleOutNet()
result = net(Tensor(input_x), Tensor(input_y))
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
def atomic_add_multi_output():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 2, 2, 256]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 2, 2, 256]).astype(np.float32)
expect = np.sum(np.square(input_x + input_y)) + (input_x + input_y)
net = MultiOutNet()
result = net(Tensor(input_x), Tensor(input_y))
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atomic_add_sum_output_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
atomic_add_sum_output()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atomic_add_single_output_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
atomic_add_single_output()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atomic_add_multi_output_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
atomic_add_multi_output()
Loading…
Cancel
Save