add force infershape for some op

pull/1224/head
wxl 4 years ago
parent 6aba1f7fad
commit c94e0fbdc6

@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
// Wait for "const input nodes" if node's shape inference function requires any.
// Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution
GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state));
if (node_item.is_output_shape_static) {
if (node_item.is_output_shape_static && node_item.is_need_force_infershape) {
return SUCCESS;
}

@ -50,6 +50,7 @@ const char *const kProfilingBpNode = "ProfilingBpNode";
const char *const kProfilingEndNode = "ProfilingEndNode";
const char *const kProfilingArNode = "ProfilingAllReduceNode";
const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE";
const char *const kForceInfershape = "_force_infershape_when_running";
Status SetOutputNameAttr(ComputeGraph &graph) {
vector<string> output_names;
@ -171,6 +172,9 @@ Status HybridModelBuilder::ValidateParams() {
Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) {
auto op_desc = node->GetOpDesc();
GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item),
"[%s] Failed to parse force_infershape node.",
node_item.NodeName().c_str());
vector<string> dependencies = node->GetOpDesc()->GetOpInferDepends();
GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies),
"[%s] Failed to parse node dependencies.",
@ -263,6 +267,15 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
return SUCCESS;
}
Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
// not care result, if no this attr, stand for the op does not need force infershape
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape , flag is %d", node_item.is_need_force_infershape);
return SUCCESS;
}
Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) {
std::set<NodePtr> dependent_input_nodes;
auto &ge_node = node_item.node;

@ -62,6 +62,7 @@ class HybridModelBuilder {
Status IdentifySameInputs(NodeItem &node_item);
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
Status ParseDependentForFusedSubgraph(NodeItem &node_item);
Status IndexTaskDefs();

@ -83,6 +83,7 @@ struct NodeItem {
bool has_observer = false;
bool has_optional_inputs = false;
bool is_output_shape_static = true;
bool is_need_force_infershape = false;
UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE;
std::string node_name;
std::string node_type;

Loading…
Cancel
Save