fix code review bug

pull/2285/head
jjfeing 5 years ago
parent 90e05579c8
commit c26274f324

@@ -37,9 +37,9 @@
 namespace mindspore {
 namespace device {
 using device::ascend::ProfilingUtils;
-void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
-  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
-  const std::vector<CNodePtr> &origin_cnode_list = kernel_graph_ptr->execution_order();
+void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  const std::vector<CNodePtr> &origin_cnode_list = kernel_graph->execution_order();
   std::vector<CNodePtr> momentum_list;
   std::vector<CNodePtr> other_list;
   for (const auto &cnode : origin_cnode_list) {
@@ -52,7 +52,7 @@ void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
   std::vector<CNodePtr> new_order_list;
   new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end());
   new_order_list.insert(new_order_list.end(), momentum_list.begin(), momentum_list.end());
-  kernel_graph_ptr->set_execution_order(new_order_list);
+  kernel_graph->set_execution_order(new_order_list);
 }
 void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
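
Note: the Reorder pass above splits the execution order into Momentum and non-Momentum kernels, then re-appends the Momentum ones at the end (the classifying loop body is elided by the hunk). The same move can be expressed with the standard library; a minimal sketch, where the predicate is a hypothetical stand-in for the elided check:

#include <algorithm>
#include <vector>

// Sketch only: equivalent of the two-list split above, written as a stable
// partition that keeps non-Momentum kernels first. std::stable_partition
// preserves relative order within each group, matching the hunk's behavior.
template <typename NodePtr, typename Pred>
void ReorderSketch(std::vector<NodePtr> *order, Pred is_momentum) {
  std::stable_partition(order->begin(), order->end(),
                        [&](const NodePtr &node) { return !is_momentum(node); });
}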

@@ -51,7 +51,7 @@ class KernelAdjust {
     static KernelAdjust instance;
     return instance;
   }
-  void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
+  void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph);
   void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
   bool StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
   void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr);

@@ -103,6 +103,7 @@ bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
 }
 bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
+  MS_EXCEPTION_IF_NULL(kernel);
   if (AnfAlgo::OutputAddrExist(kernel, index)) {
     return true;
   }
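
Note: most hunks in this commit add MS_EXCEPTION_IF_NULL guards in response to review comments. As a rough sketch of what such a guard amounts to (the real MindSpore macro reports through MS_LOG(EXCEPTION) with file and line information; this is an approximation, not the actual definition):

#include <stdexcept>
#include <string>

// Approximate behavior of a null-pointer guard macro: fail loudly at the
// check site instead of crashing at a later dereference. Not the actual
// MindSpore definition.
#define MS_EXCEPTION_IF_NULL_SKETCH(ptr)                                \
  do {                                                                  \
    if ((ptr) == nullptr) {                                             \
      throw std::runtime_error(std::string("The pointer [") + #ptr +    \
                               "] is null.");                           \
    }                                                                   \
  } while (0)
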
@@ -217,6 +218,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
   auto device_address =
     CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
   MS_EXCEPTION_IF_NULL(device_address);
+  MS_EXCEPTION_IF_NULL(mem_manager_);
   auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
   if (!ret) {
     MS_LOG(EXCEPTION) << "Malloc device memory failed.";
@@ -618,6 +620,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
     auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
     auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
+    MS_EXCEPTION_IF_NULL(device_address);
     kernel::AddressPtr input = std::make_shared<kernel::Address>();
     MS_EXCEPTION_IF_NULL(input);
     input->addr = device_address->ptr_;
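
Note: the new guard on device_address matters because the pointer returned by GetPrevNodeOutputAddr is dereferenced two lines later to fill in the launch address. A self-contained sketch of that pattern with hypothetical stand-in types (only ptr_ is confirmed by the hunk; size_ is assumed by analogy):

#include <cstddef>
#include <memory>
#include <stdexcept>

// Hypothetical stand-ins mirroring only the fields the hunk touches.
struct DeviceAddressSketch {
  void *ptr_ = nullptr;
  size_t size_ = 0;  // assumed field, by analogy with ptr_
};
struct AddressSketch {
  void *addr = nullptr;
  size_t size = 0;
};

std::shared_ptr<AddressSketch> MakeInputAddress(const DeviceAddressSketch *device_address) {
  // Mirrors the MS_EXCEPTION_IF_NULL(device_address) guard this commit adds:
  // fail before the dereference below, not during it.
  if (device_address == nullptr) {
    throw std::runtime_error("The pointer [device_address] is null.");
  }
  auto input = std::make_shared<AddressSketch>();
  input->addr = device_address->ptr_;
  input->size = device_address->size_;
  return input;
}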

@@ -68,6 +68,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
   } else if (flag == kDynamicMem) {
     ptr = MallocDynamicMem(size, false);
   } else if (flag == kReuseDynamicMem) {
+    MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
     ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   }
   return ptr;
@@ -75,6 +76,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
 uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
   if (flag == kReuseDynamicMem) {
+    MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
     return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
   }
   return MallocDynamicMem(size, false);
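
Note: both MallocOutputMem and MallocWorkSpaceMem dispatch on a flag selecting static memory, freshly allocated dynamic memory, or a pointer recycled by the memory-reuse planner, and the new guards cover the reuse branch, where mem_reuse_util_ptr_ can be unset. A compact sketch of that dispatch shape; MallocByFlag and the stub types are hypothetical, while the flag names mirror the hunk:

#include <cstddef>
#include <cstdint>
#include <stdexcept>

enum MemFlag { kStaticMem, kDynamicMem, kReuseDynamicMem };  // names mirror the hunk

// Stub reuse planner: hands back a previously planned pointer for an output.
struct ReuseUtilSketch {
  uint8_t *GetNodeOutputPtr(int node_id, size_t index) { return nullptr; }  // stub
};

struct MemoryManagerSketch {
  ReuseUtilSketch *mem_reuse_util_ptr_ = nullptr;
  uint8_t *MallocDynamicMem(size_t size, bool communication_mem) { return nullptr; }  // stub

  // Illustrative shape of the allocators: dispatch on the flag, guarding the
  // reuse branch exactly where this commit adds its check.
  uint8_t *MallocByFlag(int node_id, size_t index, MemFlag flag, size_t size) {
    if (flag == kDynamicMem) {
      return MallocDynamicMem(size, false);
    }
    if (flag == kReuseDynamicMem) {
      if (mem_reuse_util_ptr_ == nullptr) {  // MS_EXCEPTION_IF_NULL equivalent
        throw std::runtime_error("mem_reuse_util_ptr_ is null.");
      }
      return mem_reuse_util_ptr_->GetNodeOutputPtr(node_id, index);
    }
    return nullptr;  // kStaticMem handling elided
  }
};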

File diff suppressed because it is too large

@@ -47,6 +47,7 @@ bool TbeOpParallelPreBuild(const std::vector<AnfNodePtr> &anf_nodes) {
   MS_EXCEPTION_IF_NULL(build_manger);
   for (const auto &anf_node : anf_nodes) {
     // gen kernel json
+    MS_EXCEPTION_IF_NULL(anf_node);
     nlohmann::json kernel_json;
     TbeKernelJsonCreator creator(OP_PRE_COMPILE);
     if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) {
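
Note: each node is pre-built by first serializing it into a kernel JSON description, so a null node would previously have crashed inside JSON generation. A toy sketch of the guarded loop, assuming nlohmann::json as in the hunk; NodeSketch and the op_name field are illustrative, not the real TBE schema:

#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

struct NodeSketch { std::string name; };  // hypothetical stand-in for AnfNode

// Toy shape of the pre-build loop: guard each node, then serialize it.
bool PreBuildSketch(const std::vector<std::shared_ptr<NodeSketch>> &nodes) {
  for (const auto &node : nodes) {
    if (node == nullptr) {  // the MS_EXCEPTION_IF_NULL(anf_node) guard
      throw std::runtime_error("The pointer [anf_node] is null.");
    }
    nlohmann::json kernel_json;
    kernel_json["op_name"] = node->name;  // illustrative field only
    // ...hand kernel_json to the parallel build manager here...
  }
  return true;
}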

@@ -223,6 +223,9 @@ constexpr char PACK[] = "Pack";
 constexpr char GATHER_ND[] = "GatherNd";
 constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
 constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
+constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
+constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
+constexpr char ADD[] = "Add";
 // Parallel don't care
 constexpr char TUPLE_GETITEM[] = "tuple_getitem";

@@ -257,6 +257,7 @@ void MemReuseUtil::SetKernelDefMap() {
 void MemReuseUtil::SetKernelDefInputs() {
   for (const auto &kernel : graph_->execution_order()) {
+    MS_EXCEPTION_IF_NULL(kernel);
     auto key = kernel.get();
     // find kernel_def according to cnode addr
     auto iter = kernel_map_.find(key);
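
Note: SetKernelDefInputs looks up each kernel's KernelDef in kernel_map_, keyed by the raw node address from kernel.get(). The added guard turns a null entry in the execution order into an explicit exception rather than a puzzling lookup miss. A small sketch of the raw-pointer-keyed lookup, with hypothetical types:

#include <map>
#include <memory>
#include <stdexcept>

struct CNodeSketch {};      // hypothetical stand-in for CNode
struct KernelDefSketch {};  // hypothetical stand-in for KernelDef

// Keyed by the raw node address, matching "auto key = kernel.get()".
using KernelMap = std::map<const CNodeSketch *, std::shared_ptr<KernelDefSketch>>;

std::shared_ptr<KernelDefSketch> FindDef(const KernelMap &kernel_map,
                                         const std::shared_ptr<CNodeSketch> &kernel) {
  if (kernel == nullptr) {  // the guard this commit adds
    throw std::runtime_error("The pointer [kernel] is null.");
  }
  auto iter = kernel_map.find(kernel.get());
  return iter == kernel_map.end() ? nullptr : iter->second;
}
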
@@ -364,6 +365,7 @@ void MemReuseUtil::SetGraphOutputRefCount() {
 void MemReuseUtil::ResetDynamicUsedRefCount() {
   for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) {
     for (auto &ref_count : iter->second) {
+      MS_EXCEPTION_IF_NULL(ref_count);
       ref_count->ref_count_dynamic_use_ = ref_count->ref_count_;
     }
   }
