|
|
|
@ -226,7 +226,10 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t
|
|
|
|
|
<< AnfAlgo::GetInputTensorNum(kernel);
|
|
|
|
|
}
|
|
|
|
|
auto input_node = kernel->input(input_idx + 1);
|
|
|
|
|
auto kernel_input = AnfAlgo::VisitKernel(input_node, 0);
|
|
|
|
|
auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
|
|
|
|
|
if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
|
|
|
|
|
}
|
|
|
|
|
auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second));
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
@ -264,7 +267,10 @@ void MemReuseUtil::SetKernelDefInputs() {
|
|
|
|
|
if (ref_ptr != nullptr) {
|
|
|
|
|
// set the inputs of this kernel_def
|
|
|
|
|
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
|
|
|
|
auto input = AnfAlgo::VisitKernel(input_node, 0);
|
|
|
|
|
auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
|
|
|
|
|
if (IsPrimitive(input.first, prim::kPrimMakeTuple)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
|
|
|
|
|
}
|
|
|
|
|
auto input_key = (input.first).get();
|
|
|
|
|
auto input_iter = kernel_map_.find(input_key);
|
|
|
|
|
if (input_iter == kernel_map_.end()) {
|
|
|
|
|