infer shape for noop

pull/9378/head
wilfChen 4 years ago
parent 0c7ba7a7fa
commit 0363a7401a

@ -41,7 +41,6 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/profiling/profiling_utils.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/ascend/ascend_memory_manager.h"
#include "debug/tensor_load.h"
#include "debug/data_dump/dump_json_parser.h"
@ -114,34 +113,6 @@ std::string GetRankId() {
}
return rank_id_str;
}
void InferShapeForNopNode(AnfNodePtr *input_node) {
MS_EXCEPTION_IF_NULL(*input_node);
if (!opt::IsNopNode(*input_node)) {
MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
return;
}
MS_LOG(INFO) << "Infer shape for nop node.";
std::stack<AnfNodePtr> nop_road;
nop_road.push(*input_node);
while (true) {
auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
auto in_node = input_node_with_idx.first;
MS_EXCEPTION_IF_NULL(in_node);
if (opt::IsNopNode(in_node)) {
nop_road.push(in_node);
*input_node = in_node;
} else {
break;
}
}
while (!nop_road.empty()) {
auto nop_node = nop_road.top();
AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
nop_road.pop();
}
}
} // namespace
std::vector<rtExceptionInfo> AscendKernelRuntime::exception_infoes_;
@ -665,15 +636,6 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
}
if (dynamic_kernel->is_dynamic_shape()) {
auto kernel_node = dynamic_kernel->kernel_node();
MS_EXCEPTION_IF_NULL(kernel_node);
auto input_size = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_size; i++) {
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i);
auto input_node = input_node_with_index.first;
MS_EXCEPTION_IF_NULL(input_node);
InferShapeForNopNode(&input_node);
}
dynamic_kernel->InferShape();
dynamic_kernel->UpdateArgs();
}

@ -18,6 +18,7 @@
#include <vector>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "common/trans.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "abstract/dshape.h"
@ -73,6 +74,7 @@ void DynamicKernel::InferShape() {
}
MS_EXCEPTION_IF_NULL(cnode_ptr_);
MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope();
InferShapeRecursive();
auto inputs = cnode_ptr_->inputs();
if (inputs.empty()) {
@ -124,5 +126,43 @@ void DynamicKernel::InferShape() {
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
cnode_ptr_->set_abstract(eval_result);
}
void DynamicKernel::InferShapeRecursive() {
auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_);
for (size_t i = 0; i < input_size; i++) {
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i);
auto input_node = input_node_with_index.first;
MS_EXCEPTION_IF_NULL(input_node);
InferShapeForNopNode(&input_node);
}
}
void DynamicKernel::InferShapeForNopNode(AnfNodePtr *input_node) {
MS_EXCEPTION_IF_NULL(*input_node);
if (!opt::IsNopNode(*input_node) || !AnfAlgo::IsDynamicShape(*input_node)) {
MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
return;
}
MS_LOG(INFO) << "Infer shape for nop node.";
std::stack<AnfNodePtr> nop_road;
nop_road.push(*input_node);
while (true) {
auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
auto in_node = input_node_with_idx.first;
MS_EXCEPTION_IF_NULL(in_node);
if (opt::IsNopNode(in_node)) {
nop_road.push(in_node);
*input_node = in_node;
} else {
break;
}
}
while (!nop_road.empty()) {
auto nop_node = nop_road.top();
AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
nop_road.pop();
}
}
} // namespace device
} // namespace mindspore

@ -52,6 +52,8 @@ class DynamicKernel {
protected:
void RebuildDependTensor();
void InferShapeRecursive();
void InferShapeForNopNode(AnfNodePtr *input_node);
void *stream_;
const CNodePtr cnode_ptr_;

Loading…
Cancel
Save