add labelswitch option

pull/10522/head
liangzelang 4 years ago
parent 2ab20c1b6e
commit 1de8a2fd5d

@ -1006,6 +1006,25 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
return node->has_default(); return node->has_default();
} }
bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == kLabelGotoOpName &&
(AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) == label_index)) {
return true;
} else if (AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
if (std::find(label_list.begin(), label_list.end(), label_index) != label_list.end()) {
return true;
}
}
return false;
}
void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());

@ -188,6 +188,8 @@ class AnfRuntimeAlgorithm {
static bool IsNodeInGraphKernel(const AnfNodePtr &node); static bool IsNodeInGraphKernel(const AnfNodePtr &node);
// check parameter is weight or data // check parameter is weight or data
static bool IsParameterWeight(const ParameterPtr &node); static bool IsParameterWeight(const ParameterPtr &node);
// checkout whether the anf node is include the label_index.
static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
// set stream id of kernel,which will be set in stream assign and be used in stream generate // set stream id of kernel,which will be set in stream assign and be used in stream generate
static void SetStreamId(uint32_t stream_id, AnfNode *node); static void SetStreamId(uint32_t stream_id, AnfNode *node);
// get stream id // get stream id

@ -1238,7 +1238,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) {
MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs"; MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs";
int32_t index = 0; int32_t index = 0;
std::vector<KernelGraphPtr> child_graphs; std::vector<KernelGraphPtr> child_graphs;
auto start_label = graph->get_start_label(); auto start_label_id = AnfAlgo::GetNodeAttr<uint32_t>(graph->get_start_label(), kAttrLabelIndex);
auto end_node = graph->get_end_goto(); auto end_node = graph->get_end_goto();
ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0); ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0);
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
@ -1247,9 +1247,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) {
auto kg = graphs_[graph_id]; auto kg = graphs_[graph_id];
auto nodes = kg->execution_order(); auto nodes = kg->execution_order();
for (uint32_t i = 0; i < nodes.size(); i++) { for (uint32_t i = 0; i < nodes.size(); i++) {
if (AnfAlgo::GetCNodeName(nodes[i]) == kLabelGotoOpName && if (AnfAlgo::IsLabelIndexInNode(nodes[i], start_label_id)) {
(AnfAlgo::GetNodeAttr<uint32_t>(nodes[i], kAttrLabelIndex) ==
AnfAlgo::GetNodeAttr<uint32_t>(start_label, kAttrLabelIndex))) {
if (i < (nodes.size() - 1)) { if (i < (nodes.size() - 1)) {
new_inputs.push_back(nodes[i + 1]); new_inputs.push_back(nodes[i + 1]);
} else { } else {

Loading…
Cancel
Save