You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
88 lines
4.2 KiB
88 lines
4.2 KiB
/**
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
|
|
#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
|
|
|
|
#include <set>
|
|
#include <map>
|
|
#include <vector>
|
|
#include <tuple>
|
|
#include "session/kernel_graph.h"
|
|
#include "utils/base_ref.h"
|
|
#include "utils/contract.h"
|
|
|
|
namespace mindspore {
|
|
namespace session {
|
|
|
|
class AscendControlParser {
|
|
public:
|
|
static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map);
|
|
static void LinkGraph(NotNull<KernelGraphPtr> kg);
|
|
|
|
static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node);
|
|
static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
|
|
NotNull<AnfNodePtr> second_node);
|
|
static void ExecutorValidate(NotNull<KernelGraphPtr> root_graph);
|
|
static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg);
|
|
|
|
private:
|
|
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
|
const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo);
|
|
static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
|
NotNull<std::set<KernelGraphPtr> *> memo);
|
|
static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
|
NotNull<std::set<KernelGraphPtr> *> memo);
|
|
static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
|
NotNull<std::set<KernelGraphPtr> *> memo);
|
|
|
|
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
|
|
const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo);
|
|
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
|
|
|
|
static void LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
|
|
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param);
|
|
static NotNull<AnfNodePtr> GetRealInput(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
|
|
NotNull<AnfNodePtr> param);
|
|
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
|
|
|
static CNodePtr GetNextRealKernel(std::vector<CNodePtr> list, size_t start);
|
|
|
|
// root graph order
|
|
static std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> GetLabelNode(
|
|
const std::vector<CNodePtr> &nodes);
|
|
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
|
|
NotNull<KernelGraphPtr> graph);
|
|
static std::vector<CNodePtr> RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
|
|
NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
|
|
|
|
static constexpr size_t kCNodePrim = 0;
|
|
static constexpr size_t kCNodeCallArg = 1;
|
|
static constexpr size_t kCNodeSwitchCond = 1;
|
|
static constexpr size_t kCNodeSwitchTrue = 2;
|
|
static constexpr size_t kCNodeSwitchFalse = 3;
|
|
static constexpr size_t kCNodeSwitchLength = 4;
|
|
static constexpr size_t kCNodePartialLength = 2;
|
|
static constexpr size_t kCNodePartialFunc = 1;
|
|
static constexpr size_t kCNodeSwitchLayerCond = 1;
|
|
static constexpr size_t kCNodeSwitchLayerBranch = 2;
|
|
static constexpr size_t kCNodeSwitchLayerLength = 3;
|
|
};
|
|
|
|
} // namespace session
|
|
} // namespace mindspore
|
|
|
|
#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
|