Bugfix for the execution order after GraphKernelSplitter

1. Remove the deprecated pass "EliminateGetitemForControlDepend".
2. Spread the MakeTuple in UpdateState's inputs at PreProcess, so that all inputs are directly connected
   to UpdateState and the pattern "Getitem-MakeTuple-UpdateState" no longer needs special handling.
   After this pass, UpdateState(U, make_tuple(op1, op2, ...)) is changed to UpdateState(U, op1, op2, ...)
   (see the first sketch below).
3. Shrink the UpdateState's inputs back into a tuple at PostProcess; this is the reverse of the above pass.
   It restores the UpdateState's format for the stages that run after GraphKernel.
4. Add a pass ExtendOutputForUpdateState; this is the main job of this commit.
   Consider this situation:
   A Cast op has multiple users in a composite kernel, while it is also in the output list and connects to
     an external UpdateState. The pass "ShapeOpsSplitter" duplicates it; after that, only one replica stays
     connected to the external UpdateState, and the others are connected to the original users respectively.
   After the pass "GraphKernelSplitter", only one split part remains connected to this UpdateState, so the
     execution order of the other parts can no longer be ensured.
   This pass extends the outputs that connect to UpdateState: if a node has an external UpdateState user, all
     outputs that depend on this node are connected to this UpdateState (see the IR example in
     update_state_formatter.h below). It may add many redundant edges; the next pass handles them.
5. Add a pass MergeOutputForUpdateState after GraphKernelSplitter.
   If an UpdateState has multiple inputs coming from the same node, only one edge is kept
   (see the second sketch below).
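
Below is a minimal sketch of the tuple spreading behind items 2 and 3. It is a hypothetical
illustration based only on the SpreadTuples declaration and the AnfNode API visible elsewhere
in this diff (IsPrimitiveCNode, cast<CNodePtr>(), inputs()); the actual implementation lives in
update_state_formatter.cc, which is not part of this excerpt.

// Hypothetical sketch of the SpreadTuples helper declared in update_state_formatter.h below.
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList result;
  for (size_t i = begin_index; i < nodes.size(); ++i) {
    if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) {
      auto make_tuple = nodes[i]->cast<CNodePtr>();
      // input(0) is the MakeTuple primitive itself, so spread from input(1) onward.
      // The recursive call also flattens nested MakeTuple nodes.
      auto spread = SpreadTuples(AnfNodePtrList(make_tuple->inputs().begin() + 1,
                                                make_tuple->inputs().end()),
                                 0);
      result.insert(result.end(), spread.begin(), spread.end());
    } else {
      result.push_back(nodes[i]);
    }
  }
  return result;
}

SpreadUpdateState can then replace UpdateState(U, make_tuple(op1, op2, ...)) with
UpdateState(U, op1, op2, ...) by spreading its inputs from index 2 onward, and
ShrinkUpdateState performs the inverse rewrite.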
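
For item 5, the edge de-duplication can be sketched as follows. GetRealNode and
MergeUpdateStateInputs are hypothetical names for illustration, not the actual internals of
MergeOutputForUpdateState; the sketch assumes <set> is included and the same AnfNode API as above.

// A tuple_getitem is produced by the tuple it reads (its input 1);
// any other node counts as its own producer.
AnfNodePtr GetRealNode(const AnfNodePtr &node) {
  if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
    return node->cast<CNodePtr>()->input(1);
  }
  return node;
}

bool MergeUpdateStateInputs(const FuncGraphPtr &func_graph, const CNodePtr &update_state) {
  // Inputs 0 and 1 are the UpdateState primitive and the input monad U;
  // de-duplicate the remaining inputs by their producing node.
  std::set<AnfNodePtr> seen;
  AnfNodePtrList new_inputs(update_state->inputs().begin(), update_state->inputs().begin() + 2);
  for (size_t i = 2; i < update_state->inputs().size(); ++i) {
    auto input = update_state->input(i);
    if (seen.insert(GetRealNode(input)).second) {
      new_inputs.push_back(input);  // first edge from this producer: keep it
    }
  }
  if (new_inputs.size() == update_state->inputs().size()) {
    return false;  // nothing merged
  }
  auto new_update_state = func_graph->NewCNode(new_inputs);
  new_update_state->set_abstract(update_state->abstract());
  return func_graph->manager()->Replace(update_state, new_update_state);
}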
pull/12926/head
dayschan 4 years ago
parent c432105d8b
commit 49f78d5424

@@ -28,6 +28,7 @@
#include "debug/anf_ir_dump.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
namespace mindspore {
namespace opt {
@@ -50,9 +51,10 @@ void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
}
} // namespace
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
bool merge_repeated_getitem = false) {
bool merge_repeated_getitem) {
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(getitem_list);
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
@@ -194,121 +196,6 @@ class UnifyRepeatedGetitem : public Pass {
}
};
/* Merge the get_item nodes that have same index.
* subgraph graph_kernel(%para1, %para2)
* %1 = TensorAdd(%para1, %para2)
* %2 = Neg(%1)
* %3 = make_tuple(%1, %2)
* return (%3)
* %1 = call @graph_kernel(%p1, %p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* %4 = ControlDepend(%0, %2)
* %5 = other_user(%3)
* --->
* subgraph graph_kernel(%para1, %para2)
* %1 = TensorAdd(%para1, %para2)
* %2 = Neg(%1)
* %3 = make_tuple(%1, %2)
* return (%3)
* %1 = call @graph_kernel(%p1, %p2)
* %3 = tuple_getitem(%1, 1)
* %4 = ControlDepend(%0, %3)
* %5 = other_user(%3)
*
* Then the output 0 can be eliminate in the later pass.
*/
class EliminateGetitemForControlDepend : public Pass {
public:
bool Run(const FuncGraphPtr &func_graph) {
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = false;
for (const auto &node : todos) {
getitems_.clear();
GetGraphKernelGetitemList(mng, node, &getitems_, false);
if (getitems_.empty()) continue;
indexes_.clear();
GetIndexesToControlDepend(mng);
FilterRedundantOutputs(node);
if (indexes_.empty()) continue;
size_t index = GetFinalIndex(node);
changed = ReplaceGetitems(mng, index) || changed;
}
return changed;
}
private:
AnfNodePtrList getitems_; // Users of GraphKernel node with multiple outputs.
std::vector<size_t> indexes_; // Indexes of MakeTuple to be eliminated.
bool ReplaceGetitems(const FuncGraphManagerPtr &mng, size_t index) {
MS_EXCEPTION_IF_NULL(getitems_[index]);
bool changed = false;
for (auto i : indexes_) {
if (i != index) {
MS_EXCEPTION_IF_NULL(getitems_[i]);
mng->Replace(getitems_[i], getitems_[index]);
changed = true;
}
}
return changed;
}
// Find the redundant output index.
// the real output should have multiple users.
void FilterRedundantOutputs(const AnfNodePtr &node) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto &users = mng->node_users();
auto maketuple = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple);
std::vector<size_t> result;
for (auto i : indexes_) {
auto real_output = maketuple->input(i + 1);
if (users[real_output].size() > 1) {
result.push_back(i);
}
}
indexes_ = std::move(result);
}
// Get the nodes that only have ControlDepend users.
void GetIndexesToControlDepend(const FuncGraphManagerPtr &mng) {
for (size_t i = 0; i < getitems_.size(); ++i) {
const AnfNodePtr &getitem = getitems_[i];
if (getitem == nullptr) {
continue;
}
const auto &getitem_user = mng->node_users()[getitem];
if (std::all_of(getitem_user.begin(), getitem_user.end(), [](const std::pair<AnfNodePtr, int> &user) {
return IsPrimitiveCNode(user.first, prim::kPrimControlDepend);
})) {
indexes_.push_back(i);
}
}
}
size_t GetFinalIndex(const AnfNodePtr &node) {
auto is_redundant_index = [this](size_t i) {
return std::find(indexes_.begin(), indexes_.end(), i) != indexes_.end();
};
for (size_t i = 0; i < getitems_.size(); ++i) {
if (getitems_[i] != nullptr && !is_redundant_index(i)) {
return i;
}
}
return indexes_[0];
}
};
} // namespace
// Remove the output without user or with virtual user (like ControlDepend)
bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
if (mng == nullptr) {
@@ -319,13 +206,11 @@ bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed;
changed = Process(func_graph) || changed;
changed = std::make_shared<EliminateHangingOutput>()->Run(func_graph) || changed;
return changed;
}
// update the GetItem(node, i) to GetItem(node, i - offset)
void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
void EliminateHangingOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
if (offset == 0) return;
MS_EXCEPTION_IF_NULL(getitem);
auto index = GetIndex(getitem);
@@ -336,7 +221,7 @@ void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, siz
SetIndex(getitem, index);
}
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(func_graph);
auto old_maketuple = func_graph->output()->cast<CNodePtr>();
@@ -379,7 +264,7 @@ AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, co
return graph_kernel_node;
}
bool EliminateRedundantOutput::Process(const FuncGraphPtr &func_graph) {
bool EliminateHangingOutput::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = FindGraphKernelsWithMultiOutput(func_graph);

@@ -20,17 +20,54 @@
namespace mindspore {
namespace opt {
class EliminateRedundantOutput : public Pass {
/* Eliminate the outputs that have no external user
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0) // the getitem(1) does not exist.
* %3 = op(%2)
* graph_kernel:
* %1 = TensorAdd(p1, p2)
* %2 = Sub(p1, p2)
* return make_tuple(%1, %2)
* --->
* %1 = call @graph_kernel(p1, p2)
* %3 = op(%1) // if only one output remains, the getitem is not used
* graph_kernel:
* %1 = TensorAdd(p1, p2)
* return %1 // the Sub was eliminated
*/
class EliminateHangingOutput : public Pass {
public:
EliminateRedundantOutput() : Pass("eliminate_redundant_output") {}
~EliminateRedundantOutput() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool Process(const FuncGraphPtr &func_graph);
// update the GetItem(node, i) to GetItem(node, i - offset)
void UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset);
AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
};
// Remove the outputs without users, or with only virtual users (like UpdateState)
class EliminateRedundantOutput : public Pass {
public:
EliminateRedundantOutput() : Pass("eliminate_redundant_output") {}
~EliminateRedundantOutput() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
bool IsSideEffectNode(const AnfNodePtr &node);
AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph);
/**
* @brief Get the GraphKernel's user getitems
*
* @param mng FuncGraphManagerPtr for the main func_graph
* @param node The cnode that indicates the GraphKernel
* @param getitem_list The user getitem list.
* @param merge_repeated_getitem If true, getitems with the same index will be merged;
* otherwise, only one getitem will be output.
* @return If the graph was changed, returns true, otherwise returns false.
*/
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
bool merge_repeated_getitem = false);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_

@@ -40,6 +40,7 @@
#include "backend/optimizer/graph_kernel/optimize_assign.h"
#include "backend/optimizer/graph_kernel/split_assign.h"
#include "backend/optimizer/graph_kernel/reorder_ops.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/pass/getitem_tuple.h"
namespace mindspore {
@@ -56,6 +57,9 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
if (is_ascend) {
pm->AddPass(std::make_shared<ReorderOps>());
}
// Spread the MakeTuple input of UpdateState
pm->AddPass(std::make_shared<SpreadUpdateState>());
return pm;
}
@@ -99,6 +103,8 @@ PassManagerPtr GraphKernelOptimizer::Split() {
// Make certain nodes redundant so that they are used by only one user,
// which can avoid unnecessary input-output and get better performance.
if (is_gpu) {
// preprocess for ShapeOpsSplitter
pm->AddPass(std::make_shared<ExtendOutputForUpdateState>());
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops));
}
@@ -106,15 +112,16 @@
// Split kernel according to costmodel
pm->AddPass(std::make_shared<GraphKernelSplitter>());
// After Simplify and Splitter, a lot of redundant getitem/maketuple
// will be exposed, use GetitemTuple Pass to delete them.
pm->AddPass(std::make_shared<GetitemTuple>());
// Eliminate the redundant nodes that were copied above but not handled by GraphKernelSplitter
if (is_gpu) {
pm->AddPass(std::make_shared<MergeOutputForUpdateState>());
pm->AddPass(std::make_shared<GraphKernelCSE>());
pm->AddPass(std::make_shared<EliminateRedundantOutput>());
}
// After Simplify and Splitter, a lot of redundant getitem/maketuple
// will be exposed, use GetitemTuple Pass to delete them.
pm->AddPass(std::make_shared<GetitemTuple>());
return pm;
}
@@ -146,6 +153,9 @@ PassManagerPtr GraphKernelOptimizer::PostProcess() {
auto pm = std::make_shared<PassManager>("graphkernel_stage7_postprocess");
// Add the new tensors to the kernel_graph
pm->AddPass(std::make_shared<BindValueToGraph>());
// Make a tuple from the inputs of UpdateState (the reverse of SpreadUpdateState)
pm->AddPass(std::make_shared<ShrinkUpdateState>());
return pm;
}
@@ -163,6 +173,12 @@ void GraphKernelOptimizer::Run(const KernelGraphPtr &kernel_graph) {
optimizer->AddPassManager(HighLevelOpt2());
optimizer->AddPassManager(Combine());
optimizer->AddPassManager(PostProcess());
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
(void)optimizer->Optimize(kernel_graph);
}

@@ -0,0 +1,163 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
/**
* @brief Spread the input tuple of UpdateState
* @example
* %1 = op1
* %2 = op2
* %3 = make_tuple(%1, %2)
* UpdateState(U, %3)
* -->
* %1 = op1
* %2 = op2
* UpdateState(U, %1, %2)
*/
class SpreadUpdateState : public Pass {
public:
SpreadUpdateState() : Pass("spread_update_state") {}
~SpreadUpdateState() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
/**
* @brief Shrink the inputs of UpdateState to a tuple
* @example
* %1 = op1
* %2 = op2
* UpdateState(U, %1, %2)
* -->
* %1 = op1
* %2 = op2
* %3 = make_tuple(%1, %2)
* UpdateState(U, %3)
*/
class ShrinkUpdateState : public Pass {
public:
ShrinkUpdateState() : Pass("shrink_update_state") {}
~ShrinkUpdateState() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
/**
* @brief Spread the MakeTuple nodes in a node list
* @param nodes The input node list
* @param begin_index The index from which spreading begins
* @example
* input
* nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
* begin_index: 1
* output
* [b, i, j, c, d, x, y, z]
* @return std::vector<AnfNodePtr>
*/
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
/**
* @brief Extend the getitem for UpdateState
* @example
* In this example, the Cast is an output of the GraphKernel and only links to an UpdateState.
* Two other outputs in the GraphKernel, Add and Sub, depend on it.
* After processing, the Cast is eliminated from the output list, and the Add and Sub are linked to the UpdateState instead.
*
* graph_kernel:
* %1 = Cast(p1)
* %2 = Add(%1, p2) // depends on Cast
* %3 = Sub(%2, p3) // depends on Cast
* %4 = Mul(p1, p2) // does not depend on Cast
* return make_tuple(%1, %2, %3, %4)
* main graph:
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0) // The Cast
* %3 = UpdateState(U, %2)
* -->
* graph_kernel:
* %1 = Cast(p1)
* %2 = Add(%1, p2) // depends on Cast
* %3 = Sub(%2, p3) // depends on Cast
* %4 = Mul(p1, p2) // does not depend on Cast
* return make_tuple(%2, %3, %4) // the Cast was eliminated from the output list
* main graph:
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0) // the Add
* %3 = tuple_getitem(%1, 1) // the Sub
* %4 = UpdateState(U, %2, %3)
*/
class ExtendOutputForUpdateState : public Pass {
public:
ExtendOutputForUpdateState() : Pass("extend_output_for_update_state") {}
~ExtendOutputForUpdateState() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
// Get the nodes that have an external UpdateState user.
void FindIndexesToUpdateState(const FuncGraphManagerPtr &mng);
void FilterIndexes(const FuncGraphPtr &func_graph);
// Find all the func_graph's outputs that depend (directly or indirectly) on the indicated (index) node.
std::vector<size_t> FindAllOutputs(const FuncGraphPtr &func_graph, size_t index);
bool ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph, size_t index);
enum ExternalUserType {
kNormalOp, // only has normal operators
kUpdateState, // only has UpdateState(s)
kMix, // UpdateState mix with normal operator
};
AnfNodePtrList getitems_; // Users of the GraphKernel nodes.
std::vector<size_t> indexes_; // Indexes of GetItem to be processed.
std::vector<ExternalUserType> external_user_type_; // The type of getitem's users.
};
/**
* @brief Merge UpdateState's inputs which link to the same node
* @example
* graph_kernel:
* %1 = Cast(p1)
* %2 = Add(%1, p2)
* %3 = Sub(%2, p3)
* %4 = Mul(p1, p2)
* return make_tuple(%1, %2, %3, %4)
* main graph:
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* %4 = tuple_getitem(%1, 2)
* %5 = UpdateState(U, %2, %3, %4) // %2, %3 and %4 all link to %1
* -->
* main graph:
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* %4 = tuple_getitem(%1, 2)
* %5 = UpdateState(U, %2) // only keep %2
*/
class MergeOutputForUpdateState : public Pass {
public:
MergeOutputForUpdateState() : Pass("merge_output_for_update_state") {}
~MergeOutputForUpdateState() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_