!218 Supporting shape inference for fused node which InferShapeFunc invokes GetInputConstData

Merge pull request !218 from 储星/subgrapjh_infershape
pull/218/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit f7a1fcacdd

@ -15,6 +15,7 @@
*/
#include "hybrid/model/hybrid_model_builder.h"
#include <algorithm>
#include "common/math/math_util.h"
#include "graph/ge_context.h"
#include "graph/build/memory/var_mem_assign_util.h"
@ -58,6 +59,37 @@ int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) {
}
return var_size;
}
Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set<OpDesc *> &data_ops) {
for (const auto &node : node_item.fused_subgraph->nodes) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
const auto &depends = op_desc->GetOpInferDepends();
if (depends.empty()) {
continue;
}
for (auto &input_name : depends) {
auto input_index = op_desc->GetInputIndexByName(input_name);
auto src_node = NodeUtils::GetInDataNodeByIndex(*node, input_index);
GE_CHECK_NOTNULL(src_node);
auto src_op_desc = src_node->GetOpDesc();
GE_CHECK_NOTNULL(src_op_desc);
if (src_node->GetType() != DATA_TYPE) {
GELOGE(UNSUPPORTED,
"[%s::%s] Node in fused subgraph can only depend on Data nodes, but depend on %s",
node_item.NodeName().c_str(),
node->GetName().c_str(),
src_node->GetType().c_str());
return UNSUPPORTED;
}
data_ops.emplace(src_op_desc.get());
}
}
return SUCCESS;
}
} // namespace
HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model)
: hybrid_model_(hybrid_model), runtime_param_(hybrid_model.root_runtime_param_) {
@ -272,6 +304,53 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
node_item.dependents_for_shape_inference.emplace_back(dep_node);
}
GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item));
return SUCCESS;
}
Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) {
if (node_item.fused_subgraph == nullptr) {
return SUCCESS;
}
std::set<OpDesc *> data_ops;
GE_CHK_STATUS_RET_NOLOG(CollectDependenciesForFusedGraph(node_item, data_ops));
for (auto &op_desc : data_ops) {
uint32_t parent_index = 0;
if (!AttrUtils::GetInt(*op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
GELOGE(INTERNAL_ERROR,
"[%s] Failed to get attr [%s]",
op_desc->GetName().c_str(),
ATTR_NAME_PARENT_NODE_INDEX.c_str());
return INTERNAL_ERROR;
}
const auto &in_anchor = node_item.node->GetInDataAnchor(parent_index);
GE_CHECK_NOTNULL(in_anchor);
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);
const auto &src_node = peer_out_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
NodeItem *src_node_item = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(src_node, &src_node_item));
op_desc->SetId(src_node_item->op_desc->GetId());
GELOGD("[%s::%s] Node id was set to that of outer src node's, src_node = %s",
node_item.NodeName().c_str(),
op_desc->GetName().c_str(),
src_node_item->NodeName().c_str());
src_node_item->has_observer = true;
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());
auto &depends = node_item.dependents_for_shape_inference;
if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) {
depends.emplace_back(src_node);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}
}
return SUCCESS;
}

@ -62,6 +62,7 @@ class HybridModelBuilder {
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
Status ParseDependentForFusedSubgraph(NodeItem &node_item);
Status IndexTaskDefs();
Status IndexSpecialNodes();
Status InitRuntimeParams();

Loading…
Cancel
Save