!11930 【GraphKernel】Replace Assign with InplaceAssign
From: @dayschan Reviewed-by: @gaoxiong1,@dylangeng Signed-off-by: @gaoxiong1pull/11930/MERGE
commit
0ff27ef3b4
@ -0,0 +1,244 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/optimizer/graph_kernel/optimize_assign.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
|
||||||
|
const std::vector<TypeId> &inputs_type,
|
||||||
|
const std::vector<std::string> &output_formats,
|
||||||
|
const std::vector<TypeId> &output_types, const CNodePtr &cnode) {
|
||||||
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||||
|
graph_info_builder.SetInputsFormat(inputs_format);
|
||||||
|
graph_info_builder.SetInputsDeviceType(inputs_type);
|
||||||
|
graph_info_builder.SetOutputsFormat(output_formats);
|
||||||
|
graph_info_builder.SetOutputsDeviceType(output_types);
|
||||||
|
graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
|
||||||
|
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||||
|
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||||
|
return graph_info_builder.Build();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If an Assign's source node was outputted with this Assign, the src-node should be removed from output list,
|
||||||
|
* external users can use the dest-node under the premise of correct execution order.
|
||||||
|
* This function find out the [index of src node in output list] and [external dest-node].
|
||||||
|
* Note:
|
||||||
|
* 1. Assign is always in output list. (links to external Depend node)
|
||||||
|
* 2. Assign's dest-node should be a Parameter.
|
||||||
|
*/
|
||||||
|
std::map<size_t, AnfNodePtr> FindAssignAndOutputVal(const CNodePtr &fg_cnode) {
|
||||||
|
// Check output includes assign
|
||||||
|
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(fg_cnode);
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto out_cnode = func_graph->output()->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(out_cnode);
|
||||||
|
std::map<size_t, AnfNodePtr> output_replace_map;
|
||||||
|
|
||||||
|
if (!IsPrimitiveCNode(out_cnode, prim::kPrimMakeTuple)) {
|
||||||
|
return output_replace_map;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trans parameter to the real input
|
||||||
|
auto ParameterToInput = [&func_graph, &fg_cnode](const AnfNodePtr &p) {
|
||||||
|
const auto ¶ms = func_graph->parameters();
|
||||||
|
size_t i = std::find(params.begin(), params.end(), p) - params.begin();
|
||||||
|
return i == params.size() ? nullptr : fg_cnode->input(i + 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto &inputs = out_cnode->inputs();
|
||||||
|
for (const auto &out : inputs) {
|
||||||
|
if (IsPrimitiveCNode(out, prim::kPrimAssign)) {
|
||||||
|
auto assign_val = out->cast<CNodePtr>()->input(2);
|
||||||
|
auto assign_parameter = out->cast<CNodePtr>()->input(1);
|
||||||
|
auto iter = std::find(inputs.begin() + 1, inputs.end(), assign_val);
|
||||||
|
if (iter != inputs.end()) {
|
||||||
|
size_t assign_val_index = iter - inputs.begin();
|
||||||
|
auto assign_to = ParameterToInput(assign_parameter);
|
||||||
|
if (assign_to != nullptr && assign_val_index > 0) {
|
||||||
|
output_replace_map[assign_val_index - 1] = assign_to;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return output_replace_map;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HasPathToParamUser(const AnfNodePtr &gk_node, const AnfNodePtr ¶m_user) {
|
||||||
|
auto mng = AnfAlgo::GetCNodeFuncGraphPtr(gk_node)->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
bool result = false;
|
||||||
|
auto IncludeUser = [&result, &gk_node](const AnfNodePtr &node) {
|
||||||
|
if (node == gk_node) {
|
||||||
|
result = true;
|
||||||
|
return EXCLUDE;
|
||||||
|
}
|
||||||
|
return result ? EXCLUDE : FOLLOW;
|
||||||
|
};
|
||||||
|
static_cast<void>(DeepLinkedGraphSearch(param_user, IncludeUser));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr AddControlDepend(const FuncGraphPtr &func_graph, const AnfNodePtr &getitem, const AnfNodePtr ¶m_user) {
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), getitem, param_user};
|
||||||
|
auto cd_node = func_graph->NewCNode(cd_inputs);
|
||||||
|
func_graph->AddNode(cd_node);
|
||||||
|
return cd_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LinkControlDepends(const FuncGraphPtr &func_graph, const AnfNodePtrList &cd_nodes) {
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
auto output_tuple = func_graph->output()->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(output_tuple);
|
||||||
|
auto cur_node = output_tuple->input(1);
|
||||||
|
for (const auto &cd : cd_nodes) {
|
||||||
|
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), cur_node, cd};
|
||||||
|
auto depend_node = func_graph->NewCNode(depend_inputs);
|
||||||
|
depend_node->set_abstract(depend_inputs[1]->abstract());
|
||||||
|
cur_node = depend_node;
|
||||||
|
}
|
||||||
|
mng->Replace(output_tuple->input(1), cur_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t GetitemIndex(const AnfNodePtr &getitem) {
|
||||||
|
auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
|
||||||
|
auto value_ptr = GetValueNode(index_node);
|
||||||
|
return GetValue<int64_t>(value_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode,
|
||||||
|
const AnfNodePtr &assign_to, int64_t removed_index) {
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
AnfNodePtrList cd_nodes;
|
||||||
|
for (const auto &getitem_iter : mng->node_users()[cnode]) {
|
||||||
|
auto getitem = getitem_iter.first;
|
||||||
|
if (GetitemIndex(getitem) != removed_index) continue;
|
||||||
|
auto getitem_users = mng->node_users()[getitem]; // get a copy of getitem's users before replacing
|
||||||
|
mng->Replace(getitem, assign_to);
|
||||||
|
|
||||||
|
for (const auto &getitem_user_iter : getitem_users) {
|
||||||
|
auto getitem_user = getitem_user_iter.first;
|
||||||
|
// 1. A previous pass `DependFormater` has ensured that all data users are directly link to its
|
||||||
|
// input, without Depend node.
|
||||||
|
// 2. If the `cnode` has another path to the getitem_user, it's unnecessary to add a ControlDepend.
|
||||||
|
if (!AnfAlgo::IsRealKernel(getitem_user) || HasPathToParamUser(cnode, getitem_user)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// keep execution order: cnode -> getitem_user
|
||||||
|
auto cd_node = AddControlDepend(func_graph, getitem, getitem_user);
|
||||||
|
cd_nodes.push_back(cd_node);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return cd_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
|
||||||
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
|
||||||
|
bool changed = false;
|
||||||
|
AnfNodePtrList control_depend_nodes;
|
||||||
|
for (const auto &n : todos) {
|
||||||
|
if (!AnfAlgo::IsGraphKernel(n)) continue;
|
||||||
|
auto cnode = n->cast<CNodePtr>();
|
||||||
|
auto replaceable_nodes = FindAssignAndOutputVal(cnode);
|
||||||
|
if (replaceable_nodes.empty()) continue;
|
||||||
|
changed = true;
|
||||||
|
for (const auto &iter : replaceable_nodes) {
|
||||||
|
auto cd_nodes = UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first);
|
||||||
|
control_depend_nodes.insert(control_depend_nodes.end(), cd_nodes.begin(), cd_nodes.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LinkControlDepends(func_graph, control_depend_nodes);
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ReplaceAssignByInplaceAssignInGraphkernel(const FuncGraphPtr &func_graph) {
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
|
bool changed = false;
|
||||||
|
for (const auto &n : todos) {
|
||||||
|
if (!AnfAlgo::CheckPrimitiveType(n, prim::kPrimAssign)) continue;
|
||||||
|
changed = true;
|
||||||
|
auto cnode = n->cast<CNodePtr>();
|
||||||
|
AnfNodePtrList inputs = {NewValueNode(prim::kPrimInplaceAssign->Clone()), cnode->input(1), cnode->input(2),
|
||||||
|
cnode->input(2)};
|
||||||
|
auto new_cnode = func_graph->NewCNode(inputs);
|
||||||
|
AnfAlgo::SetNodeAttr("fake_output", MakeValue(true), new_cnode);
|
||||||
|
new_cnode->set_abstract(inputs.back()->abstract());
|
||||||
|
new_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
|
std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(cnode);
|
||||||
|
std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
|
||||||
|
input_formats.push_back(input_formats.back());
|
||||||
|
input_types.push_back(input_types.back());
|
||||||
|
std::vector<std::string> output_formats = {input_formats.back()};
|
||||||
|
std::vector<TypeId> output_types = {input_types.back()};
|
||||||
|
auto graph_sel_info = BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, cnode);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, new_cnode.get());
|
||||||
|
mng->Replace(cnode, new_cnode);
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RepalceAssignByInplaceAssign(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
|
|
||||||
|
auto changed = false;
|
||||||
|
for (const auto &n : todos) {
|
||||||
|
if (!AnfAlgo::IsGraphKernel(n)) continue;
|
||||||
|
auto graph_kernel_fg = AnfAlgo::GetCNodeFuncGraphPtr(n);
|
||||||
|
MS_EXCEPTION_IF_NULL(graph_kernel_fg);
|
||||||
|
changed = ReplaceAssignByInplaceAssignInGraphkernel(graph_kernel_fg) || changed;
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool OptimizeAssign::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(func_graph, true);
|
||||||
|
func_graph->set_manager(mng);
|
||||||
|
}
|
||||||
|
auto res = RepalceOutputByParameter(func_graph);
|
||||||
|
if (res) {
|
||||||
|
mng->RemoveRoots();
|
||||||
|
mng->KeepRoots({func_graph});
|
||||||
|
}
|
||||||
|
return RepalceAssignByInplaceAssign(func_graph);
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,34 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_ASSIGN_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_ASSIGN_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
#include "backend/session/kernel_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class OptimizeAssign : public Pass {
|
||||||
|
public:
|
||||||
|
OptimizeAssign() : Pass("optimize_assign") {}
|
||||||
|
~OptimizeAssign() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
};
|
||||||
|
using OptimizeAssignPtr = std::shared_ptr<OptimizeAssign>;
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_ASSIGN_H_
|
Loading…
Reference in new issue