Move label maker to stage2.

pull/224/head
unknown 4 years ago
parent 191215d278
commit 5d9f4fd2fc

@ -26,12 +26,17 @@
namespace ge {
LabelAllocator::LabelAllocator(const ComputeGraphPtr &graph) : compute_graph_(graph) {}
Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) {
Status LabelAllocator::AssignFunctionalLabels() {
if (compute_graph_ == nullptr) {
GELOGE(INTERNAL_ERROR, "ComputeGraph not set, Assign labels failed.");
return INTERNAL_ERROR;
}
if (compute_graph_->GetGraphUnknownFlag()) {
GELOGD("Graph[%s] is unknown graph, skip label allocator.", compute_graph_->GetName().c_str());
return SUCCESS;
}
// Add label task for sub graph.
GELOGI("AssignFunctionalLabels start: %s.", compute_graph_->GetName().c_str());
std::set<NodePtr> functional_nodes;
@ -42,7 +47,7 @@ Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) {
}
// Add label for functional op.
label_index = 0;
uint32_t label_index = 0;
for (auto node : functional_nodes) {
LabelMakerPtr maker = LabelMakerFactory::Instance().Create(node->GetType(), compute_graph_, node);
if (maker == nullptr) {
@ -56,6 +61,7 @@ Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) {
}
}
(void)AttrUtils::SetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_index);
GELOGI("AssignFunctionalLabels success.");
return SUCCESS;
}

@ -28,7 +28,7 @@ class LabelAllocator {
explicit LabelAllocator(const ComputeGraphPtr &graph);
~LabelAllocator() = default;
Status AssignFunctionalLabels(uint32_t &label_index);
Status AssignFunctionalLabels();
private:
bool CollectFunctionalNode(ComputeGraphPtr &graph, std::set<NodePtr> &functional_nodes);

@ -348,7 +348,11 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr
auto compute_graph = subgraph->subgraph_info.GetSubGraph();
for (NodePtr &node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node->GetOpDesc());
if (IsEngineSkip(*subgraph) && node->GetInNodes().empty()) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_RTS_LABEL_NODE)) {
node->GetOpDesc()->SetStreamId(context.default_stream);
GELOGD("Node %s of type %s in subgraph %s is assigned parent stream %ld (engine: %s).", node->GetName().c_str(),
node->GetType().c_str(), subgraph->name.c_str(), context.default_stream, engine_name.c_str());
} else if (IsEngineSkip(*subgraph) && node->GetInNodes().empty()) {
GELOGD("Node %s of type %s in subgraph %s doesn't need to assign a stream (engine: %s).",
node->GetName().c_str(), node->GetType().c_str(), subgraph->name.c_str(), engine_name.c_str());
} else {

@ -23,7 +23,6 @@
#include "graph/anchor.h"
#include "graph/attr_value.h"
#include "graph/buffer.h"
#include "graph/build/label_allocator.h"
#include "graph/build/stream_allocator.h"
#include "graph/common/omg_util.h"
#include "graph/common/ge_call_wrapper.h"
@ -42,7 +41,6 @@
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/passes/memcpy_addr_async_pass.h"
#include "init/gelib.h"
#include "memory/memory_assigner.h"
#include "omg/version.h"
@ -692,25 +690,8 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) {
GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams");
// Assign functional op labels.
GE_TIMESTAMP_START(AssignFunctionalLabels);
LabelAllocator label_allocator(compute_graph_);
GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(label_num_), "Assign label failed.");
GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels");
// Add memcpy_addr_async node.
rtFeatureType_t feature_type = FEATURE_TYPE_MEMCPY;
int32_t feature_info = MEMCPY_INFO_SUPPORT_ZEROCOPY;
int64_t value = 0;
rtError_t rt_ret = rtGetRtCapability(feature_type, feature_info, &value);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtGetRtCapability failed.");
return RT_FAILED;
} else {
GE_TIMESTAMP_START(AddMemcpyAddrAsyncNode);
MemcpyAddrAsyncPass memcpy_addr;
GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph_), "Add memcpy_addr_async node failed.");
GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run.");
}
label_num_ = 0;
(void)AttrUtils::GetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_num_);
GE_TIMESTAMP_START(AssignMemory);
MemoryAssigner mem_assigner(compute_graph_);

File diff suppressed because it is too large Load Diff

@ -60,9 +60,8 @@ class LabelMaker {
ComputeGraphPtr parent_graph_;
private:
void SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc);
void SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc);
void SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc);
void LinkToGraphHead(const ComputeGraphPtr &graph, const NodePtr &node);
void LinkToGraphTail(const ComputeGraphPtr &graph, const NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_LABEL_MAKER_H_

@ -100,6 +100,8 @@
#include "graph/passes/subgraph_const_migration_pass.h"
#include "graph/passes/unused_args_clean_pass.h"
#include "graph/passes/global_step_insert_pass.h"
#include "graph/passes/memcpy_addr_async_pass.h"
#include "graph/build/label_allocator.h"
#include "graph/utils/tensor_adapter.h"
#include "graph/utils/type_utils.h"
#include "graph/graph_util.h"
@ -634,6 +636,13 @@ Status GraphManager::PreRunAfterOptimizeSubGraph(const GraphNodePtr &graph_node,
GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts",
GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuildForRts,
compute_graph);
Status ret = compute_graph->TopologicalSorting();
if (ret != SUCCESS) {
GELOGE(ret, "Graph topological sort failed, ret:%d.", ret);
return ret;
}
GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id);
GELOGI("PreRun:PreRunAfterOptimizeSubGraph success.");
return SUCCESS;
@ -2186,6 +2195,18 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
return ret;
}
// Assign functional op labels.
GE_TIMESTAMP_START(AssignFunctionalLabels);
LabelAllocator label_allocator(compute_graph);
GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(), "Assign label failed.");
GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels");
// Add memcpy addr asynchronous node.
GE_TIMESTAMP_START(AddMemcpyAddrAsyncNode);
MemcpyAddrAsyncPass memcpy_addr;
GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed.");
GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run.");
// After while sub graph handle, mark all node rw type
auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph);
if (result != SUCCESS) {
@ -2196,11 +2217,6 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
ChangeConstTypeWhenTraining(compute_graph);
ret = compute_graph->TopologicalSorting();
if (ret != SUCCESS) {
GELOGE(ret, "Graph topological sort failed, ret:%d.", ret);
return ret;
}
GELOGI("End optimize after merge sub graph.");
return SUCCESS;
}

@ -25,6 +25,14 @@
namespace ge {
Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph);
int64_t value = 0;
rtError_t rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMCPY, MEMCPY_INFO_SUPPORT_ZEROCOPY, &value);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtGetRtCapability failed, error=0x%x.", rt_ret);
return RT_FAILED;
}
for (auto &node : graph->GetAllNodes()) {
auto op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
@ -210,9 +218,18 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr
return nullptr;
}
int64_t stream_id = out_of_user_data->GetOpDesc()->GetStreamId();
op_desc->SetStreamId(stream_id);
GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id);
string stream_label;
if (AttrUtils::GetStr(out_of_user_data->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
(void)AttrUtils::SetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label);
GELOGD("Node %s set stream label: %s", op_desc->GetName().c_str(), stream_label.c_str());
}
bool rts_label_node = false;
if (AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_RTS_LABEL_NODE, rts_label_node)) {
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, rts_label_node);
GELOGD("Node %s set rts label node attribute", op_desc->GetName().c_str());
}
bool labeled_input = false;
(void)ge::AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, labeled_input);
if (labeled_input) {

Loading…
Cancel
Save