fix cpu reshape bug

pull/2174/head
kswang 5 years ago
parent f10e297498
commit 236d6c6de4

@ -161,8 +161,12 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
}
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
MS_EXCEPTION_IF_NULL(tensor);
address->ptr_ = tensor->data_c(true);
address->ref_count_ = INIT_NODE_REF;
if (address->ref_count_ > 0 && address->ptr_ != nullptr) {
tensor->set_device_address(address);
} else {
address->ptr_ = tensor->data_c(true);
address->ref_count_ = INIT_NODE_REF;
}
tensor->set_dirty(false);
return tensor;
} else if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) {
@ -211,6 +215,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
}
tensor->set_dirty(true);
}
address->ref_count_ = INIT_NODE_REF;
tensor->set_device_address(address);
}
@ -220,7 +225,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
// new output and bind ptr
auto output_nodes = kernel_graph->outputs();
for (const auto &item : output_nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
auto out = CreatTensorForOutput(item_with_index.first, item_with_index.second, input_map);
outputs->push_back(std::move(out));
}

Loading…
Cancel
Save