!11930 【GraphKernel】Replace Assign with InplaceAssign

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @gaoxiong1
pull/11930/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0ff27ef3b4

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -21,6 +21,12 @@ from mindspore import log as logger
from . import model from . import model
def reset_graphmode_for_inplaceassign(graph_list, graph_mode):
for i, g in enumerate(graph_list):
if any([op['name'] == 'InplaceAssign' for op in g['op_desc']]):
graph_mode[i] = 'composite'
def split_with_json(json_str: str): def split_with_json(json_str: str):
"""Call costmodel to split GraphKernel""" """Call costmodel to split GraphKernel"""
try: try:
@ -30,6 +36,7 @@ def split_with_json(json_str: str):
graph_split, graph_mode = model.split(comp.graph, target) graph_split, graph_mode = model.split(comp.graph, target)
is_multi_graph = len(graph_split) > 1 is_multi_graph = len(graph_split) > 1
graph_list = list(map(comp.dump, graph_split)) graph_list = list(map(comp.dump, graph_split))
reset_graphmode_for_inplaceassign(graph_list, graph_mode)
result = {"multi_graph": is_multi_graph, result = {"multi_graph": is_multi_graph,
"graph_desc": graph_list, "graph_desc": graph_list,
"graph_mode": graph_mode} "graph_mode": graph_mode}

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -95,6 +95,12 @@ AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
return result; return result;
} }
bool IsSideEffectNode(const AnfNodePtr &node) {
std::vector<PrimitivePtr> side_effect_nodes = {prim::kPrimAssign, prim::kPrimInplaceAssign};
return std::any_of(side_effect_nodes.begin(), side_effect_nodes.end(),
[&node](const PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
}
/* Unify the repeated output in a func_graph. /* Unify the repeated output in a func_graph.
* %1 = call @graph_kernel(p1, p2) * %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0) * %2 = tuple_getitem(%1, 0)
@ -318,6 +324,7 @@ bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
return changed; return changed;
} }
// update the GetItem(node, i) to GetItem(node, i - offset)
void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) { void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
if (offset == 0) return; if (offset == 0) return;
MS_EXCEPTION_IF_NULL(getitem); MS_EXCEPTION_IF_NULL(getitem);
@ -338,14 +345,17 @@ AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, co
AbstractBasePtrList abstract_list; AbstractBasePtrList abstract_list;
size_t offset = 0; size_t offset = 0;
for (size_t i = 0; i < getitems.size(); ++i) { for (size_t i = 0; i < getitems.size(); ++i) {
if (getitems[i] == nullptr) { // If a node has no user, it should be eliminated, but except for side-effect node.
if (getitems[i] == nullptr && !IsSideEffectNode(old_maketuple->input(i + 1))) {
offset++; offset++;
} else { } else {
new_maketuple_inputs.push_back(old_maketuple->input(i + 1)); new_maketuple_inputs.push_back(old_maketuple->input(i + 1));
abstract_list.push_back(old_maketuple->input(i + 1)->abstract()); abstract_list.push_back(old_maketuple->input(i + 1)->abstract());
if (getitems[i] != nullptr) {
UpdateGetitemIndex(getitems[i], offset); UpdateGetitemIndex(getitems[i], offset);
} }
} }
}
if (offset == 0) return nullptr; if (offset == 0) return nullptr;
if (new_maketuple_inputs.size() == 1) { if (new_maketuple_inputs.size() == 1) {
MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty"; MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty";

@ -0,0 +1,244 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/optimize_assign.h"
#include <algorithm>
#include <vector>
#include <string>
#include <map>
#include "base/core_ops.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
const std::vector<TypeId> &inputs_type,
const std::vector<std::string> &output_formats,
const std::vector<TypeId> &output_types, const CNodePtr &cnode) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(inputs_format);
graph_info_builder.SetInputsDeviceType(inputs_type);
graph_info_builder.SetOutputsFormat(output_formats);
graph_info_builder.SetOutputsDeviceType(output_types);
graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
return graph_info_builder.Build();
}
/**
* If an Assign's source node was outputted with this Assign, the src-node should be removed from output list,
* external users can use the dest-node under the premise of correct execution order.
* This function find out the [index of src node in output list] and [external dest-node].
* Note:
* 1. Assign is always in output list. (links to external Depend node)
* 2. Assign's dest-node should be a Parameter.
*/
std::map<size_t, AnfNodePtr> FindAssignAndOutputVal(const CNodePtr &fg_cnode) {
// Check output includes assign
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(fg_cnode);
MS_EXCEPTION_IF_NULL(func_graph);
auto out_cnode = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_cnode);
std::map<size_t, AnfNodePtr> output_replace_map;
if (!IsPrimitiveCNode(out_cnode, prim::kPrimMakeTuple)) {
return output_replace_map;
}
// Trans parameter to the real input
auto ParameterToInput = [&func_graph, &fg_cnode](const AnfNodePtr &p) {
const auto &params = func_graph->parameters();
size_t i = std::find(params.begin(), params.end(), p) - params.begin();
return i == params.size() ? nullptr : fg_cnode->input(i + 1);
};
const auto &inputs = out_cnode->inputs();
for (const auto &out : inputs) {
if (IsPrimitiveCNode(out, prim::kPrimAssign)) {
auto assign_val = out->cast<CNodePtr>()->input(2);
auto assign_parameter = out->cast<CNodePtr>()->input(1);
auto iter = std::find(inputs.begin() + 1, inputs.end(), assign_val);
if (iter != inputs.end()) {
size_t assign_val_index = iter - inputs.begin();
auto assign_to = ParameterToInput(assign_parameter);
if (assign_to != nullptr && assign_val_index > 0) {
output_replace_map[assign_val_index - 1] = assign_to;
}
}
}
}
return output_replace_map;
}
bool HasPathToParamUser(const AnfNodePtr &gk_node, const AnfNodePtr &param_user) {
auto mng = AnfAlgo::GetCNodeFuncGraphPtr(gk_node)->manager();
MS_EXCEPTION_IF_NULL(mng);
bool result = false;
auto IncludeUser = [&result, &gk_node](const AnfNodePtr &node) {
if (node == gk_node) {
result = true;
return EXCLUDE;
}
return result ? EXCLUDE : FOLLOW;
};
static_cast<void>(DeepLinkedGraphSearch(param_user, IncludeUser));
return result;
}
AnfNodePtr AddControlDepend(const FuncGraphPtr &func_graph, const AnfNodePtr &getitem, const AnfNodePtr &param_user) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), getitem, param_user};
auto cd_node = func_graph->NewCNode(cd_inputs);
func_graph->AddNode(cd_node);
return cd_node;
}
void LinkControlDepends(const FuncGraphPtr &func_graph, const AnfNodePtrList &cd_nodes) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto output_tuple = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_tuple);
auto cur_node = output_tuple->input(1);
for (const auto &cd : cd_nodes) {
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), cur_node, cd};
auto depend_node = func_graph->NewCNode(depend_inputs);
depend_node->set_abstract(depend_inputs[1]->abstract());
cur_node = depend_node;
}
mng->Replace(output_tuple->input(1), cur_node);
}
int64_t GetitemIndex(const AnfNodePtr &getitem) {
auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
auto value_ptr = GetValueNode(index_node);
return GetValue<int64_t>(value_ptr);
}
AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode,
const AnfNodePtr &assign_to, int64_t removed_index) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
AnfNodePtrList cd_nodes;
for (const auto &getitem_iter : mng->node_users()[cnode]) {
auto getitem = getitem_iter.first;
if (GetitemIndex(getitem) != removed_index) continue;
auto getitem_users = mng->node_users()[getitem]; // get a copy of getitem's users before replacing
mng->Replace(getitem, assign_to);
for (const auto &getitem_user_iter : getitem_users) {
auto getitem_user = getitem_user_iter.first;
// 1. A previous pass `DependFormater` has ensured that all data users are directly link to its
// input, without Depend node.
// 2. If the `cnode` has another path to the getitem_user, it's unnecessary to add a ControlDepend.
if (!AnfAlgo::IsRealKernel(getitem_user) || HasPathToParamUser(cnode, getitem_user)) {
continue;
}
// keep execution order: cnode -> getitem_user
auto cd_node = AddControlDepend(func_graph, getitem, getitem_user);
cd_nodes.push_back(cd_node);
}
break;
}
return cd_nodes;
}
bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
auto todos = TopoSort(func_graph->get_return());
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = false;
AnfNodePtrList control_depend_nodes;
for (const auto &n : todos) {
if (!AnfAlgo::IsGraphKernel(n)) continue;
auto cnode = n->cast<CNodePtr>();
auto replaceable_nodes = FindAssignAndOutputVal(cnode);
if (replaceable_nodes.empty()) continue;
changed = true;
for (const auto &iter : replaceable_nodes) {
auto cd_nodes = UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first);
control_depend_nodes.insert(control_depend_nodes.end(), cd_nodes.begin(), cd_nodes.end());
}
}
LinkControlDepends(func_graph, control_depend_nodes);
return changed;
}
bool ReplaceAssignByInplaceAssignInGraphkernel(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = TopoSort(func_graph->get_return());
bool changed = false;
for (const auto &n : todos) {
if (!AnfAlgo::CheckPrimitiveType(n, prim::kPrimAssign)) continue;
changed = true;
auto cnode = n->cast<CNodePtr>();
AnfNodePtrList inputs = {NewValueNode(prim::kPrimInplaceAssign->Clone()), cnode->input(1), cnode->input(2),
cnode->input(2)};
auto new_cnode = func_graph->NewCNode(inputs);
AnfAlgo::SetNodeAttr("fake_output", MakeValue(true), new_cnode);
new_cnode->set_abstract(inputs.back()->abstract());
new_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(cnode);
std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
input_formats.push_back(input_formats.back());
input_types.push_back(input_types.back());
std::vector<std::string> output_formats = {input_formats.back()};
std::vector<TypeId> output_types = {input_types.back()};
auto graph_sel_info = BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, cnode);
AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, new_cnode.get());
mng->Replace(cnode, new_cnode);
}
return changed;
}
bool RepalceAssignByInplaceAssign(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = TopoSort(func_graph->get_return());
auto changed = false;
for (const auto &n : todos) {
if (!AnfAlgo::IsGraphKernel(n)) continue;
auto graph_kernel_fg = AnfAlgo::GetCNodeFuncGraphPtr(n);
MS_EXCEPTION_IF_NULL(graph_kernel_fg);
changed = ReplaceAssignByInplaceAssignInGraphkernel(graph_kernel_fg) || changed;
}
return changed;
}
} // namespace
bool OptimizeAssign::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto res = RepalceOutputByParameter(func_graph);
if (res) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return RepalceAssignByInplaceAssign(func_graph);
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_ASSIGN_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_ASSIGN_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace opt {
class OptimizeAssign : public Pass {
public:
OptimizeAssign() : Pass("optimize_assign") {}
~OptimizeAssign() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
using OptimizeAssignPtr = std::shared_ptr<OptimizeAssign>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_ASSIGN_H_

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -413,6 +413,36 @@ std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePt
return format; return format;
} }
std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputDeviceTypes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
}
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
auto types = build_info->GetAllInputDeviceTypes();
return types;
}
std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputDeviceTypes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
}
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
auto types = build_info->GetAllOutputDeviceTypes();
return types;
}
std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) { std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) { if (!AnfAlgo::IsRealKernel(node)) {

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -105,6 +105,10 @@ class AnfRuntimeAlgorithm {
static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node); static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node);
// get all inputs format select of anf node // get all inputs format select of anf node
static std::vector<std::string> GetAllInputFormats(const AnfNodePtr &node); static std::vector<std::string> GetAllInputFormats(const AnfNodePtr &node);
// get all inputs type select of anf node
static std::vector<TypeId> GetAllInputDeviceTypes(const AnfNodePtr &node);
// get all outputs type select of anf node
static std::vector<TypeId> GetAllOutputDeviceTypes(const AnfNodePtr &node);
// get origin data format select of anf node // get origin data format select of anf node
static std::string GetOriginDataFormat(const AnfNodePtr &node); static std::string GetOriginDataFormat(const AnfNodePtr &node);
// get output format select of anf node // get output format select of anf node

@ -56,6 +56,7 @@
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h" #include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include "backend/optimizer/graph_kernel/value_graph_binder.h" #include "backend/optimizer/graph_kernel/value_graph_binder.h"
#include "backend/optimizer/graph_kernel/parallel_fusion.h" #include "backend/optimizer/graph_kernel/parallel_fusion.h"
#include "backend/optimizer/graph_kernel/optimize_assign.h"
#include "backend/optimizer/pass/communication_op_fusion.h" #include "backend/optimizer/pass/communication_op_fusion.h"
#include "backend/optimizer/pass/getitem_tuple.h" #include "backend/optimizer/pass/getitem_tuple.h"
#include "common/trans.h" #include "common/trans.h"
@ -188,6 +189,8 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>()); pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
pm->AddPass(std::make_shared<opt::OptimizeAssign>());
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
pm->AddPass(std::make_shared<opt::RaiseReductionPrecision>()); pm->AddPass(std::make_shared<opt::RaiseReductionPrecision>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>()); pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>()); pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());

@ -379,6 +379,9 @@ inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg"); inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");
inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict"); inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
// GraphKernel ops
inline const PrimitivePtr kPrimInplaceAssign = std::make_shared<Primitive>("InplaceAssign");
class DoSignaturePrimitive : public Primitive { class DoSignaturePrimitive : public Primitive {
public: public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)

Loading…
Cancel
Save