1. Removed the deprecated pass "EliminateGetitemForControlDepend". 2. Spread the MakeTuple in UpdateState's inputs at PreProcess, so that all inputs are directly connected to UpdateState and the pattern "Getitem-MakeTuple-UpdateState" no longer needs to be considered. After this pass, UpdateState(U, make_tuple(op1, op2, ...)) is changed to UpdateState(U, op1, op2, ...). 3. Shrink the UpdateState's inputs at PostProcess. This is the reverse operation of the above pass; it restores the UpdateState's format for the processing after GraphKernel. 4. Add a pass ExtendOutputForUpdateState, which is the main job of this commit. Consider this situation: a Cast op has multiple users in a composite kernel, while it is also in the output list and connects to an external UpdateState. In the pass "ShapeOpsSplitter", it will be duplicated. After that, only one replica will be connected to the external UpdateState; the others will be connected to its original users respectively. After the pass "GraphKernelSplitter", only one part will be connected to this UpdateState, so the execution order of the other nodes cannot be ensured. This pass extends the node that connects to UpdateState: if a node has an external UpdateState user, all outputs that depend on this node will be connected to this UpdateState. It may add many redundant edges; the next pass will handle them. 5. Add a pass MergeOutputForUpdateState after GraphKernelSplitter. If an UpdateState has multiple inputs from the same node, only one edge will be kept. pull/12926/head
parent
c432105d8b
commit
49f78d5424
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,163 @@
|
||||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
/**
|
||||
* @brief Spread the input tuple of UpdateState
|
||||
* @example
|
||||
* %1 = op1
|
||||
* %2 = op2
|
||||
* %3 = make_tuple(%1, %2)
|
||||
* UpdateState(U, %3)
|
||||
* -->
|
||||
* %1 = op1
|
||||
* %2 = op2
|
||||
* UpdateState(U, %1, %2)
|
||||
*/
|
||||
class SpreadUpdateState : public Pass {
|
||||
public:
|
||||
SpreadUpdateState() : Pass("spread_update_state") {}
|
||||
~SpreadUpdateState() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Shrink the inputs of UpdateState to a tuple
|
||||
* @example
|
||||
* %1 = op1
|
||||
* %2 = op2
|
||||
* UpdateState(U, %1, %2)
|
||||
* -->
|
||||
* %1 = op1
|
||||
* %2 = op2
|
||||
* %3 = make_tuple(%1, %2)
|
||||
* UpdateState(U, %3)
|
||||
*/
|
||||
class ShrinkUpdateState : public Pass {
|
||||
public:
|
||||
ShrinkUpdateState() : Pass("shrink_update_state") {}
|
||||
~ShrinkUpdateState() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Spread the MakeTuple in node list
|
||||
* @param nodes
|
||||
* @param begin_index
|
||||
* @example
|
||||
* input
|
||||
* nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
|
||||
* begin_index: 1
|
||||
* output
|
||||
* [b, i, j, c, d, x, y, z]
|
||||
* @return std::vector<AnfNodePtr>
|
||||
*/
|
||||
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
|
||||
|
||||
/**
|
||||
* @brief Extend the getitem for UpdateState
|
||||
* @example
|
||||
* In this example, the Cast is an output of GraphKernel and only links to an UpdateState,
|
||||
* it has two users in GraphKernel, Add and Sub, which are all outputs.
|
||||
* after processing, the Cast was eliminate from output list and the Add and Sub was linked to UpdateState.
|
||||
*
|
||||
* graph_kernel:
|
||||
* %1 = Cast(p1)
|
||||
* %2 = Add(%1, p2) // depends on Cast
|
||||
* %3 = Sub(%2, p3) // depends on Cast
|
||||
* %4 = Mul(p1, p2) // not depends on Cast
|
||||
* return make_tuple(%1, %2, %3, %4)
|
||||
* main graph:
|
||||
* %1 = call @graph_kernel(p1, p2)
|
||||
* %2 = tuple_getitem(%1, 0) // The Cast
|
||||
* %3 = UpdateState(U, %2)
|
||||
* -->
|
||||
* graph_kernel:
|
||||
* %1 = Cast(p1)
|
||||
* %2 = Add(%1, p2) // depends on Cast
|
||||
* %3 = Sub(%2, p3) // depends on Cast
|
||||
* %4 = Mul(p1, p2) // not depends on Cast
|
||||
* return make_tuple(%2, %3, %4) // the Cast was eliminated from output list
|
||||
* main graph:
|
||||
* %1 = call @graph_kernel(p1, p2)
|
||||
* %2 = tuple_getitem(%1, 0) // the Add
|
||||
* %3 = tuple_getitem(%1, 1) // the Sub
|
||||
* %4 = UpdateState(U, %2, %3)
|
||||
*/
|
||||
class ExtendOutputForUpdateState : public Pass {
|
||||
public:
|
||||
ExtendOutputForUpdateState() : Pass("extend_output_for_update_state") {}
|
||||
~ExtendOutputForUpdateState() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
// Get the nodes that have external UpdateState user.
|
||||
void FindIndexesToUpdateState(const FuncGraphManagerPtr &mng);
|
||||
void FilterIndexes(const FuncGraphPtr &func_graph);
|
||||
// Find all the func_graph's outputs that depends (directly or indirectly) on the indicated(index) node.
|
||||
std::vector<size_t> FindAllOutputs(const FuncGraphPtr &func_graph, size_t index);
|
||||
bool ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph, size_t index);
|
||||
|
||||
enum ExternalUserType {
|
||||
kNormalOp, // only has normal operators
|
||||
kUpdateState, // only has UpdateState(s)
|
||||
kMix, // UpdateState mix with normal operator
|
||||
};
|
||||
AnfNodePtrList getitems_; // Users of the GraphKernel nodes.
|
||||
std::vector<size_t> indexes_; // Indexes of GetItem to be processed.
|
||||
std::vector<ExternalUserType> external_user_type_; // The type of getitem's users.
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Merge UpdateState's inputs which link to the same node
|
||||
* @example
|
||||
* graph_kernel:
|
||||
* %1 = Cast(p1)
|
||||
* %2 = Add(%1, p2)
|
||||
* %3 = Sub(%2, p3)
|
||||
* %4 = Mul(p1, p2)
|
||||
* return make_tuple(%1, %2, %3, %4)
|
||||
* main graph:
|
||||
* %1 = call @graph_kernel(p1, p2)
|
||||
* %2 = tuple_getitem(%1, 0)
|
||||
* %3 = tuple_getitem(%1, 1)
|
||||
* %4 = tuple_getitem(%1, 2)
|
||||
* %5 = UpdateState(U, %2, %3, %4) // the %2 %3 %4 are all link to %1
|
||||
* -->
|
||||
* main graph:
|
||||
* %1 = call @graph_kernel(p1, p2)
|
||||
* %2 = tuple_getitem(%1, 0)
|
||||
* %3 = tuple_getitem(%1, 1)
|
||||
* %4 = tuple_getitem(%1, 2)
|
||||
* %5 = UpdateState(U, %2) // only keep %2
|
||||
*/
|
||||
class MergeOutputForUpdateState : public Pass {
|
||||
public:
|
||||
MergeOutputForUpdateState() : Pass("merge_output_for_update_state") {}
|
||||
~MergeOutputForUpdateState() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
|
Loading…
Reference in new issue