diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 5960da1a..306b3eda 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -125,6 +125,7 @@ set(TRAIN_SRC_LIST "graph/manager/graph_var_manager.cc" "graph/manager/host_mem_manager.cc" "graph/manager/rdma_pool_allocator.cc" + $<$>:graph/manager/host_mem_allocator.cc> "graph/manager/memory_api.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/trans_var_data_utils.cc" @@ -165,7 +166,8 @@ set(TRAIN_SRC_LIST "graph/passes/dropout_pass.cc" "graph/passes/hccl_group_pass.cc" "graph/passes/enter_pass.cc" - "graph/passes/assign_pass.cc" + "graph/passes/assign_remove_pass.cc" + $<$>:graph/passes/inplace_support_check_pass.cc> "graph/passes/flow_ctrl_pass.cc" "graph/passes/global_step_insert_pass.cc" "host_kernels/transpose_kernel.cc" @@ -401,6 +403,7 @@ set(INFER_SRC_LIST "graph/manager/graph_var_manager.cc" "graph/manager/host_mem_manager.cc" "graph/manager/rdma_pool_allocator.cc" + $<$>:graph/manager/host_mem_allocator.cc> "graph/manager/graph_mem_allocator.cc" "graph/manager/graph_caching_allocator.cc" "model/ge_model.cc" @@ -521,7 +524,8 @@ set(INFER_SRC_LIST "graph/passes/cond_remove_pass.cc" "graph/passes/for_pass.cc" "graph/passes/enter_pass.cc" - "graph/passes/assign_pass.cc" + "graph/passes/assign_remove_pass.cc" + $<$>:graph/passes/inplace_support_check_pass.cc> "graph/passes/addn_pass.cc" "graph/passes/common_subexpression_elimination_pass.cc" "graph/passes/remove_same_const_pass.cc" diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index 8e548815..4ca18864 100644 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -28,6 +28,7 @@ set(SRC_LIST "../graph/manager/trans_var_data_utils.cc" "../graph/manager/util/debug.cc" "../graph/manager/rdma_pool_allocator.cc" + $<$>:../graph/manager/host_mem_allocator.cc> "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" "../model/ge_model.cc" "../model/ge_root_model.cc" diff --git a/ge/executor/module.mk b/ge/executor/module.mk index 34c2a37e..87abdade 100644 --- a/ge/executor/module.mk +++ b/ge/executor/module.mk @@ -15,6 +15,7 @@ local_ge_executor_src_files := \ ../graph/manager/graph_manager_utils.cc \ ../graph/manager/graph_var_manager.cc \ ../graph/manager/rdma_pool_allocator.cc \ + ../graph/manager/host_mem_allocator.cc \ ../graph/manager/graph_mem_allocator.cc \ ../graph/manager/graph_caching_allocator.cc \ ../graph/manager/trans_var_data_utils.cc \ diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index e20456d5..74d09404 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -64,6 +64,7 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ graph/manager/graph_var_manager.cc \ graph/manager/host_mem_manager.cc \ graph/manager/rdma_pool_allocator.cc \ + graph/manager/host_mem_allocator.cc \ graph/manager/graph_mem_allocator.cc \ graph/manager/graph_caching_allocator.cc \ @@ -195,7 +196,8 @@ OMG_HOST_SRC_FILES := \ graph/passes/useless_control_out_remove_pass.cc \ graph/passes/for_pass.cc \ graph/passes/enter_pass.cc \ - graph/passes/assign_pass.cc \ + graph/passes/assign_remove_pass.cc \ + graph/passes/inplace_support_check_pass.cc \ graph/passes/addn_pass.cc \ graph/passes/common_subexpression_elimination_pass.cc \ graph/passes/transop_symmetry_elimination_pass.cc \ diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc index e17f73de..0f46b4cb 100755 --- a/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -26,6 +26,31 @@ #include "common/math/math_util.h" namespace { +#ifndef ONLY_COMPILE_OPEN_SRC +#define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ + case (DTYPE): { \ + GeTensorPtr ge_tensor = nullptr; \ + if (need_create_flag) { \ + uint64_t size = data_num * sizeof(TYPE); \ + ge_tensor = MakeShared(out_desc, size); \ + GE_CHECK_NOTNULL(ge_tensor); \ + GELOGD("node:%s allocate output %zu success, size=%lld", op_desc->GetName().c_str(), i, size); \ + ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ + ge_tensor->MutableTensorDesc().SetShape(out_desc.GetShape()); \ + outputs.emplace_back(ge_tensor); \ + } else { \ + ge_tensor = outputs[i]; \ + GE_CHECK_NOTNULL(ge_tensor); \ + GELOGD("node:%s existed output %zu", op_desc->GetName().c_str(), i); \ + } \ + auto tensor = TensorAdapter::AsTensor(*ge_tensor); \ + auto tensor_name = op_desc->GetOutputNameByIndex(i); \ + GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu", \ + op_desc->GetName().c_str(), i); \ + named_outputs.emplace(tensor_name, tensor); \ + break; \ + } +#else #define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ case (DTYPE): { \ GeTensorPtr ge_tensor = nullptr; \ @@ -61,6 +86,7 @@ namespace { named_outputs.emplace(tensor_name, tensor); \ break; \ } +#endif } namespace ge { diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 9706dadb..5a99dc8c 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -94,6 +94,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/manager/graph_var_manager.cc \ graph/manager/host_mem_manager.cc \ graph/manager/rdma_pool_allocator.cc \ + graph/manager/host_mem_allocator.cc \ graph/manager/memory_api.cc \ graph/manager/model_manager/event_manager.cc \ graph/manager/trans_var_data_utils.cc \ @@ -134,7 +135,8 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/dropout_pass.cc \ graph/passes/hccl_group_pass.cc \ graph/passes/enter_pass.cc \ - graph/passes/assign_pass.cc \ + graph/passes/assign_remove_pass.cc \ + graph/passes/inplace_support_check_pass.cc \ graph/passes/flow_ctrl_pass.cc \ graph/passes/global_step_insert_pass.cc \ host_kernels/transpose_kernel.cc \ diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index f7646f30..030b864e 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -38,6 +38,10 @@ #include "graph/partition/stage_partition.h" #include "graph/passes/addn_pass.h" #include "graph/passes/bitcast_pass.h" +#ifndef ONLY_COMPILE_OPEN_SRC +#include "graph/passes/assign_remove_pass.h" +#include "graph/passes/inplace_support_check_pass.h" +#endif #include "graph/passes/atomic_addr_clean_pass.h" #include "graph/passes/attach_stream_label_pass.h" #include "graph/passes/cast_remove_pass.h" @@ -2250,10 +2254,20 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { ReshapeRemovePass reshape_remove_pass; CondRemovePass condition_remove_pass; BitcastPass bitcast_pass; +#ifndef ONLY_COMPILE_OPEN_SRC + AssignRemovePass assign_remove_pass; + InplaceSupportCheckPass inplace_support_check_pass; +#endif names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass); names_to_passes.emplace_back("BitcastPass", &bitcast_pass); +#ifndef ONLY_COMPILE_OPEN_SRC + if (GetContext().GetHostExecFlag()) { + names_to_passes.emplace_back("AssignRemovePass", &assign_remove_pass); + names_to_passes.emplace_back("InplaceSupportCheckPass", &inplace_support_check_pass); + } +#endif GE_TIMESTAMP_START(names_to_passes); ret = GEPass(compute_graph).Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "OptimizeStage2::MergedGraphNameToPasses"); diff --git a/ge/graph/manager/graph_mem_allocator.cc b/ge/graph/manager/graph_mem_allocator.cc index 7ee7df20..4e31d835 100755 --- a/ge/graph/manager/graph_mem_allocator.cc +++ b/ge/graph/manager/graph_mem_allocator.cc @@ -19,7 +19,9 @@ #include #include "graph/manager/graph_caching_allocator.h" #include "graph/manager/rdma_pool_allocator.h" - +#ifndef ONLY_COMPILE_OPEN_SRC +#include "graph/manager/host_mem_allocator.h" +#endif namespace ge { void MemoryAllocator::Initialize(uint32_t device_id) { GELOGI("MemoryAllocator::Initialize"); @@ -190,6 +192,12 @@ Status MemManager::Initialize(const std::vector &memory_type) { GELOGE(ge::INTERNAL_ERROR, "Create RdmaAllocator failed."); return ge::INTERNAL_ERROR; } +#ifndef ONLY_COMPILE_OPEN_SRC + if (InitAllocator(memory_type, host_allocator_map_) != SUCCESS) { + GELOGE(ge::INTERNAL_ERROR, "Create HostMemAllocator failed."); + return ge::INTERNAL_ERROR; + } +#endif return SUCCESS; } @@ -211,6 +219,9 @@ void MemManager::Finalize() noexcept { // caching and rdma allocator use memory allocator, so finalize them first FinalizeAllocatorMap(caching_allocator_map_); FinalizeAllocatorMap(rdma_allocator_map_); +#ifndef ONLY_COMPILE_OPEN_SRC + FinalizeAllocatorMap(host_allocator_map_); +#endif FinalizeAllocatorMap(memory_allocator_map_); } @@ -239,4 +250,9 @@ CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) { RdmaPoolAllocator &MemManager::RdmaPoolInstance(rtMemType_t memory_type) { return Instance().GetAllocator(memory_type, rdma_allocator_map_); } +#ifndef ONLY_COMPILE_OPEN_SRC +HostMemAllocator &MemManager::HostMemInstance(rtMemType_t memory_type) { + return Instance().GetAllocator(memory_type, host_allocator_map_); +} +#endif } // namespace ge diff --git a/ge/graph/manager/graph_mem_allocator.h b/ge/graph/manager/graph_mem_allocator.h index 2723ae5c..6cdbd9b4 100644 --- a/ge/graph/manager/graph_mem_allocator.h +++ b/ge/graph/manager/graph_mem_allocator.h @@ -139,7 +139,9 @@ class MemoryAllocator { using MemoryAllocatorPtr = std::shared_ptr; class CachingAllocator; class RdmaPoolAllocator; - +#ifndef ONLY_COMPILE_OPEN_SRC +class HostMemAllocator; +#endif class MemManager { public: MemManager(); @@ -148,6 +150,9 @@ class MemManager { static MemoryAllocator *Instance(rtMemType_t memory_type); CachingAllocator &CachingInstance(rtMemType_t memory_type); RdmaPoolAllocator &RdmaPoolInstance(rtMemType_t memory_type); +#ifndef ONLY_COMPILE_OPEN_SRC + HostMemAllocator &HostMemInstance(rtMemType_t memory_type); +#endif MemManager(const MemManager &) = delete; MemManager &operator=(const MemManager &) = delete; /// @@ -235,6 +240,9 @@ class MemManager { std::map memory_allocator_map_; std::map caching_allocator_map_; std::map rdma_allocator_map_; +#ifndef ONLY_COMPILE_OPEN_SRC + std::map host_allocator_map_; +#endif std::recursive_mutex allocator_mutex_; }; } // namespace ge diff --git a/ge/graph/manager/host_mem_allocator.cc b/ge/graph/manager/host_mem_allocator.cc new file mode 100644 index 00000000..ca2b5124 --- /dev/null +++ b/ge/graph/manager/host_mem_allocator.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/manager/host_mem_allocator.h" +#include "framework/common/debug/ge_log.h" +#include "common/ge/ge_util.h" + +namespace ge { +const void *HostMemAllocator::Malloc(const std::shared_ptr &aligned_ptr, size_t size) { + if (aligned_ptr == nullptr) { + GELOGW("Insert a null aligned_ptr"); + return nullptr; + } + GELOGD("allocate existed host memory succ, size=%zu", size); + allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr }; + return aligned_ptr->Get(); +} + +uint8_t *HostMemAllocator::Malloc(size_t size) { + GELOGD("start to malloc host memory, size=%zu", size); + std::lock_guard lock(mutex_); + std::shared_ptr aligned_ptr = MakeShared(size); + if (aligned_ptr == nullptr) { + GELOGE(INTERNAL_ERROR, "make shared_ptr for AlignedPtr failed"); + return nullptr; + } + allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr }; + GELOGD("allocate host memory succ, size=%zu", size); + return aligned_ptr->MutableGet(); +} + +Status HostMemAllocator::Free(const void *memory_addr) { + if (memory_addr == nullptr) { + GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer"); + return GE_GRAPH_FREE_FAILED; + } + + std::lock_guard lock(mutex_); + auto it = allocated_blocks_.find(memory_addr); + if (it == allocated_blocks_.end()) { + GELOGE(PARAM_INVALID, "Invalid memory pointer"); + return PARAM_INVALID; + } + it->second.second.reset(); + allocated_blocks_.erase(it); + + return SUCCESS; +} + +void HostMemAllocator::Clear() { + for (auto &block : allocated_blocks_) { + block.second.second.reset(); + } + allocated_blocks_.clear(); +} +} // namespace ge diff --git a/ge/graph/manager/host_mem_allocator.h b/ge/graph/manager/host_mem_allocator.h new file mode 100644 index 00000000..b9dbdc4c --- /dev/null +++ b/ge/graph/manager/host_mem_allocator.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_MANAGER_HOST_MEM_ALLOCATOR_H_ +#define GE_GRAPH_MANAGER_HOST_MEM_ALLOCATOR_H_ + +#include +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "graph/aligned_ptr.h" +#include "runtime/mem.h" + +namespace ge { +class HostMemAllocator { + public: + explicit HostMemAllocator(rtMemType_t) {} + ~HostMemAllocator() = default; + + HostMemAllocator(const HostMemAllocator &) = delete; + HostMemAllocator &operator=(const HostMemAllocator &) = delete; + + Status Initialize() { + Clear(); + return SUCCESS; + } + void Finalize() { Clear(); } + + const void *Malloc(const std::shared_ptr& aligned_ptr, size_t size); + uint8_t *Malloc(size_t size); + Status Free(const void *memory_addr); + + std::pair> GetAlignedPtr(const void *addr) { return allocated_blocks_[addr]; } + + private: + void Clear(); + + std::map>> allocated_blocks_; + // lock around all operations + mutable std::mutex mutex_; +}; +} // namespace ge + +#endif // GE_GRAPH_MANAGER_HOST_MEM_ALLOCATOR_H_ diff --git a/ge/graph/manager/host_mem_manager.cc b/ge/graph/manager/host_mem_manager.cc index c99c9e87..c9a33f5c 100644 --- a/ge/graph/manager/host_mem_manager.cc +++ b/ge/graph/manager/host_mem_manager.cc @@ -43,16 +43,29 @@ Status SharedMemAllocator::Allocate(SharedMemInfo &mem_info) { return GE_GRAPH_MEMORY_ALLOC_FAILED; } mem_info.fd = output_para.fd; +#ifndef ONLY_COMPILE_OPEN_SRC + mem_info.host_aligned_ptr = AlignedPtr::BuildFromAllocFunc([&output_para](std::unique_ptr &ptr) { + ptr.reset(reinterpret_cast(output_para.ptr)); + }, + [](uint8_t *ptr) { + ptr = nullptr; + }); +#else mem_info.host_address = reinterpret_cast(output_para.ptr); +#endif mem_info.device_address = reinterpret_cast(output_para.devPtr); return SUCCESS; } Status SharedMemAllocator::DeAllocate(SharedMemInfo &mem_info) { GELOGD("SharedMemAllocator::DeAllocate"); +#ifndef ONLY_COMPILE_OPEN_SRC + rtFreeHostSharedMemoryIn free_para = {mem_info.shm_name.c_str(), mem_info.mem_size, mem_info.fd, + mem_info.host_aligned_ptr->MutableGet(), mem_info.device_address}; +#else rtFreeHostSharedMemoryIn free_para = {mem_info.shm_name.c_str(), mem_info.mem_size, mem_info.fd, mem_info.host_address, mem_info.device_address}; - +#endif rtError_t rt_ret = rtFreeHostSharedMemory(&free_para); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api(rtFreeHostSharedMemory) failed, ret: 0x%X.", rt_ret); diff --git a/ge/graph/manager/host_mem_manager.h b/ge/graph/manager/host_mem_manager.h index 66bd5826..f204c9e4 100644 --- a/ge/graph/manager/host_mem_manager.h +++ b/ge/graph/manager/host_mem_manager.h @@ -42,7 +42,11 @@ struct SharedMemInfo { uint64_t mem_size = 0; int fd = 0; uint8_t *device_address = nullptr; +#ifndef ONLY_COMPILE_OPEN_SRC + std::shared_ptr host_aligned_ptr = nullptr; +#else uint8_t *host_address = nullptr; +#endif SharedMemInfo() = default; SharedMemInfo(string name, uint64_t size) : op_name(std::move(name)), mem_size(size) {} }; diff --git a/ge/graph/passes/assign_pass.cc b/ge/graph/passes/assign_pass.cc deleted file mode 100644 index bb7a0f04..00000000 --- a/ge/graph/passes/assign_pass.cc +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/assign_pass.h" - -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_attr_define.h" - -namespace { -const uint32_t kValidInputNodeOutputNum = 1; -const int32_t kAssignRefInputIndex = 0; -const int32_t kAssignValueInputIndex = 1; -} - -namespace ge { -Status AssignPass::Run(NodePtr &node) { - GELOGD("AssignPass running"); - if (node->GetType() != ASSIGN) { - GELOGD("No need run AssignPass on [%s, %s].", node->GetName().c_str(), node->GetType().c_str()); - return SUCCESS; - } - - const auto &ref_in_anchor = node->GetInDataAnchor(kAssignRefInputIndex); - const auto &value_in_anchor = node->GetInDataAnchor(kAssignValueInputIndex); - if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) { - GELOGE(FAILED, "In data anchor is null, node:%s", node->GetName().c_str()); - return FAILED; - } - const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor(); - const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor(); - if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) { - GELOGE(FAILED, "Peer data anchor is null, node:%s", node->GetName().c_str()); - return FAILED; - } - - if (IsCondMatch(node, ref_peer_anchor, value_peer_anchor)) { - /// - /// variable not-const not-const - /// \ / | - /// \ / | - /// Assign ----> variable - /// | | - /// | | - /// node node - /// - GELOGI("Optimization for assign_node %s start", node->GetName().c_str()); - if (IsolateAndDeleteNode(node, {kAssignRefInputIndex}) != SUCCESS) { - GELOGE(FAILED, "Isolate and delete assign_node %s failed.", node->GetName().c_str()); - return FAILED; - } - AddNodeDeleted(node); - - const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc(); - const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc(); - if ((ref_input == nullptr) || (value_input == nullptr)) { - GELOGE(FAILED, "value input is null"); - return FAILED; - } - if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME, - ref_input->GetName())) { - GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed."); - return FAILED; - } - - // variable has and only has one input - if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str()); - return FAILED; - } - if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str()); - return FAILED; - } - } - - GELOGD("AssignPass success"); - return SUCCESS; -} - -/// -/// @brief Check if need optimize for assign_node -/// @param [in] assign_node -/// @param [in] peer_data_anchor for ref_input of assign_node -/// @param [in] peer_data_anchor for value_input of assign_node -/// @return Status -/// -bool AssignPass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_peer_anchor, - const OutDataAnchorPtr &value_peer_anchor) { - GELOGD("Check if assign_node %s match optimization condition, ref_input: %s, value_input: %s", - node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(), - value_peer_anchor->GetOwnerNode()->GetName().c_str()); - - const std::string &value_type = value_peer_anchor->GetOwnerNode()->GetType(); - if ((value_type == CONSTANTOP) || (value_type == CONSTANT)) { - GELOGD("value input is const"); - return false; - } - - const std::string &ref_type = ref_peer_anchor->GetOwnerNode()->GetType(); - if ((ref_type != VARIABLE) && (ref_type != VARIABLEV2)) { - GELOGD("ref input is not var"); - return false; - } - if (!ref_peer_anchor->GetOwnerNode()->GetInDataNodes().empty()) { - GELOGD("ref input has data input"); - return false; - } - - if ((ref_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum) || - (value_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum)) { - GELOGD("ref / value input has other output(s)"); - return false; - } - - GELOGD("Optimization condition matches, assign_node: %s", node->GetName().c_str()); - return true; -} -} // namespace ge diff --git a/ge/graph/passes/assign_remove_pass.cc b/ge/graph/passes/assign_remove_pass.cc new file mode 100644 index 00000000..5029b9c3 --- /dev/null +++ b/ge/graph/passes/assign_remove_pass.cc @@ -0,0 +1,250 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/assign_remove_pass.h" +#include "framework/common/debug/log.h" +#include "graph/utils/graph_utils.h" +#include "graph/debug/ge_attr_define.h" + +namespace { +constexpr uint32_t kValidInputNodeOutputNum = 1; +constexpr int32_t kAssignRefInputIndex = 0; +constexpr int32_t kAssignValueInputIndex = 1; +static const std::set kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA, + ge::CONSTANT, ge::CONSTANTOP, + ge::VARIABLE, ge::VARIABLEV2 }; +} + +namespace ge { +#ifndef ONLY_COMPILE_OPEN_SRC +Status AssignRemovePass::Run(NodePtr &node) { + GELOGD("AssignRemovePass running"); + + if (TransformAttr(node) != SUCCESS) { + GELOGE(FAILED, "Transform assign_var_name attr failed, node=%s", node->GetName().c_str()); + return FAILED; + } + + if (node->GetType() == ASSIGN) { + if (OptimizedAssignNode(node) != SUCCESS) { + GELOGE(FAILED, "Optimize for assign_node %s failed", node->GetName().c_str()); + return FAILED; + } + } + + GELOGD("AssignRemovePass success"); + return SUCCESS; +} + +/// +/// @brief Optimize for assign_node +/// @param [in] assign_node +/// @return Status +/// +Status AssignRemovePass::OptimizedAssignNode(NodePtr &assign_node) { + const auto &ref_in_anchor = assign_node->GetInDataAnchor(kAssignRefInputIndex); + const auto &value_in_anchor = assign_node->GetInDataAnchor(kAssignValueInputIndex); + if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) { + GELOGE(FAILED, "In data anchor is null, node:%s", assign_node->GetName().c_str()); + return FAILED; + } + const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor(); + const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor(); + if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) { + GELOGE(FAILED, "Peer data anchor is null, node:%s", assign_node->GetName().c_str()); + return FAILED; + } + + if (IsCondMatch(assign_node, ref_peer_anchor, value_peer_anchor)) { + /// + /// variable not-const not-const + /// \ / | + /// \ / | + /// Assign ----> variable + /// | | + /// | | + /// node node + /// + GELOGD("Optimization for assign_node %s start", assign_node->GetName().c_str()); + if (IsolateAndDeleteNode(assign_node, {kAssignRefInputIndex}) != SUCCESS) { + GELOGE(FAILED, "Isolate and delete assign_node %s failed.", assign_node->GetName().c_str()); + return FAILED; + } + + const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc(); + const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc(); + if ((ref_input == nullptr) || (value_input == nullptr)) { + GELOGE(FAILED, "value input is null"); + return FAILED; + } + + // variable has and only has one input + if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str()); + return FAILED; + } + if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str()); + return FAILED; + } + + GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s", + value_input->GetName().c_str(), ref_input->GetName().c_str()); + if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME, + ref_input->GetName())) { + GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed."); + return FAILED; + } + auto value_node = value_peer_anchor->GetOwnerNode(); + AddRePassNode(value_node); + } + return SUCCESS; +} + +/// +/// @brief Transform assign_var_name attr +/// @param [in] node +/// @return Status +/// +Status AssignRemovePass::TransformAttr(NodePtr &node) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDesc()) { + int32_t inplace_input_idx = -1; + std::string assign_var_name; + if (AttrUtils::GetInt(output_desc, INPLACE_SUPPORT_INPUT_INDEX, inplace_input_idx) && + AttrUtils::GetStr(output_desc, ASSIGN_VAR_NAME, assign_var_name)) { + GELOGD("Transform attr ASSIGN_VAR_NAME on node %s, assign_var_name=%s, inplace_input_idx=%d, ", + node->GetName().c_str(), assign_var_name.c_str(), inplace_input_idx); + const auto &in_data_anchor = node->GetInDataAnchor(inplace_input_idx); + GE_CHECK_NOTNULL(in_data_anchor); + const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_data_anchor); + auto in_node = peer_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(in_node->GetOpDesc()); + GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s", in_node->GetName().c_str(), assign_var_name.c_str()); + if (!AttrUtils::SetStr(in_node->GetOpDesc()->MutableOutputDesc(peer_data_anchor->GetIdx()), + ASSIGN_VAR_NAME, assign_var_name)) { + GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed."); + return FAILED; + } + AddRePassNode(in_node); + } + } + return SUCCESS; +} +#else +Status AssignRemovePass::Run(NodePtr &node) { + GELOGD("AssignRemovePass running"); + if (node->GetType() != ASSIGN) { + GELOGD("No need run AssignRemovePass on [%s, %s].", node->GetName().c_str(), node->GetType().c_str()); + return SUCCESS; + } + + const auto &ref_in_anchor = node->GetInDataAnchor(kAssignRefInputIndex); + const auto &value_in_anchor = node->GetInDataAnchor(kAssignValueInputIndex); + if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) { + GELOGE(FAILED, "In data anchor is null, node:%s", node->GetName().c_str()); + return FAILED; + } + const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor(); + const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor(); + if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) { + GELOGE(FAILED, "Peer data anchor is null, node:%s", node->GetName().c_str()); + return FAILED; + } + + if (IsCondMatch(node, ref_peer_anchor, value_peer_anchor)) { + /// + /// variable not-const not-const + /// \ / | + /// \ / | + /// Assign ----> variable + /// | | + /// | | + /// node node + /// + GELOGI("Optimization for assign_node %s start", node->GetName().c_str()); + if (IsolateAndDeleteNode(node, {kAssignRefInputIndex}) != SUCCESS) { + GELOGE(FAILED, "Isolate and delete assign_node %s failed.", node->GetName().c_str()); + return FAILED; + } + AddNodeDeleted(node); + + const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc(); + const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc(); + if ((ref_input == nullptr) || (value_input == nullptr)) { + GELOGE(FAILED, "value input is null"); + return FAILED; + } + if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME, + ref_input->GetName())) { + GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed."); + return FAILED; + } + + // variable has and only has one input + if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str()); + return FAILED; + } + if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str()); + return FAILED; + } + } + + GELOGD("AssignRemovePass success"); + return SUCCESS; +} +#endif +/// +/// @brief Check if need optimize for assign_node +/// @param [in] assign_node +/// @param [in] peer_data_anchor for ref_input of assign_node +/// @param [in] peer_data_anchor for value_input of assign_node +/// @return Status +/// +bool AssignRemovePass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_peer_anchor, + const OutDataAnchorPtr &value_peer_anchor) { + GELOGD("Check if assign_node %s match optimization condition, ref_input: %s, value_input: %s", + node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(), + value_peer_anchor->GetOwnerNode()->GetName().c_str()); + + if (kNoTaskNodeTypes.count(value_peer_anchor->GetOwnerNode()->GetType()) > 0) { + GELOGD("value input is not calculate node"); + return false; + } + + const std::string &ref_type = ref_peer_anchor->GetOwnerNode()->GetType(); + if ((ref_type != VARIABLE) && (ref_type != VARIABLEV2)) { + GELOGD("ref input is not var"); + return false; + } + if (!ref_peer_anchor->GetOwnerNode()->GetInDataNodes().empty()) { + GELOGD("ref input has data input"); + return false; + } + + if ((ref_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum) || + (value_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum)) { + GELOGD("ref / value input has other output(s)"); + return false; + } + + GELOGD("Optimization condition matches, assign_node: %s", node->GetName().c_str()); + return true; +} +} // namespace ge diff --git a/ge/graph/passes/assign_pass.h b/ge/graph/passes/assign_remove_pass.h similarity index 68% rename from ge/graph/passes/assign_pass.h rename to ge/graph/passes/assign_remove_pass.h index 11cf1073..f8ef2e13 100644 --- a/ge/graph/passes/assign_pass.h +++ b/ge/graph/passes/assign_remove_pass.h @@ -14,17 +14,32 @@ * limitations under the License. */ -#ifndef GE_GRAPH_PASSES_ASSIGN_PASS_H_ -#define GE_GRAPH_PASSES_ASSIGN_PASS_H_ +#ifndef GE_GRAPH_PASSES_ASSIGN_REMOVE_PASS_H_ +#define GE_GRAPH_PASSES_ASSIGN_REMOVE_PASS_H_ #include "graph/passes/base_pass.h" namespace ge { -class AssignPass : public BaseNodePass { +class AssignRemovePass : public BaseNodePass { public: Status Run(NodePtr &node) override; private: +#ifndef ONLY_COMPILE_OPEN_SRC + /// + /// @brief Optimize for assign_node + /// @param [in] assign_node + /// @return Status + /// + Status OptimizedAssignNode(NodePtr &assign_node); + + /// + /// @brief Transform assign_var_name attr + /// @param [in] node + /// @return Status + /// + Status TransformAttr(NodePtr &node); +#endif /// /// @brief Check if need optimize for assign_node /// @param [in] assign_node @@ -36,4 +51,4 @@ class AssignPass : public BaseNodePass { const OutDataAnchorPtr &value_peer_anchor); }; } // namespace ge -#endif // GE_GRAPH_PASSES_ASSIGN_PASS_H_ +#endif // GE_GRAPH_PASSES_ASSIGN_REMOVE_PASS_H_ diff --git a/ge/graph/passes/constant_fuse_same_pass.cc b/ge/graph/passes/constant_fuse_same_pass.cc index d0970c59..8ee89648 100644 --- a/ge/graph/passes/constant_fuse_same_pass.cc +++ b/ge/graph/passes/constant_fuse_same_pass.cc @@ -19,13 +19,7 @@ #include #include #include -#include #include - -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "graph/debug/ge_attr_define.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" @@ -121,11 +115,21 @@ void ConstantFuseSamePass::GetFuseConstNodes(ComputeGraphPtr &graph, TypeUtils::DataTypeToSerialString(data_type).c_str()); continue; } +#ifndef ONLY_COMPILE_OPEN_SRC + if ((type_size != 0) && (weight->MutableData().GetAlignedPtr() == nullptr)) { + GELOGW("aligned_ptr is null while size is not 0"); + continue; + } +#endif ++insert_const_nums; SameConstKey map_key; map_key.data_size = type_size; +#ifndef ONLY_COMPILE_OPEN_SRC + map_key.aligned_ptr = weight->MutableData().GetAlignedPtr(); +#else map_key.data = weight->GetData().GetData(); +#endif map_key.data_type = data_type; map_key.format = output_tensor->GetFormat(); map_key.shape = output_tensor->GetShape().GetDims(); diff --git a/ge/graph/passes/constant_fuse_same_pass.h b/ge/graph/passes/constant_fuse_same_pass.h index 4935da84..ae39c707 100755 --- a/ge/graph/passes/constant_fuse_same_pass.h +++ b/ge/graph/passes/constant_fuse_same_pass.h @@ -21,14 +21,20 @@ #include #include #include - +#ifndef ONLY_COMPILE_OPEN_SRC +#include "graph/aligned_ptr.h" +#endif #include "graph/types.h" #include "inc/graph_pass.h" namespace ge { struct SameConstKey { int data_size; +#ifndef ONLY_COMPILE_OPEN_SRC + std::shared_ptr aligned_ptr; +#else const uint8_t *data; +#endif DataType data_type; Format format; std::vector shape; @@ -38,10 +44,19 @@ struct SameConstKey { if (data_size != key.data_size) { return data_size < key.data_size; } +#ifndef ONLY_COMPILE_OPEN_SRC + if (data_size != 0) { + int ret = memcmp(aligned_ptr->Get(), key.aligned_ptr->Get(), data_size); + if (ret != 0) { + return ret < 0; + } + } +#else int ret = memcmp(data, key.data, data_size); if (ret != 0) { return ret < 0; } +#endif if (data_type != key.data_type) { return data_type < key.data_type; } diff --git a/ge/graph/passes/inplace_support_check_pass.cc b/ge/graph/passes/inplace_support_check_pass.cc new file mode 100644 index 00000000..73cc7f3b --- /dev/null +++ b/ge/graph/passes/inplace_support_check_pass.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/inplace_support_check_pass.h" +#include "framework/common/debug/log.h" +#include "graph/utils/graph_utils.h" +#include "graph/debug/ge_attr_define.h" + +namespace { +constexpr uint32_t kInplaceSupportOutputIndex = 0; +constexpr uint32_t kInplaceSupportOutputNum = 1; +static const std::set kSrcNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA, + ge::CONSTANT, ge::CONSTANTOP, + ge::VARIABLE, ge::VARIABLEV2 }; +} + +namespace ge { +Status InplaceSupportCheckPass::Run(NodePtr &node) { + GELOGD("InplaceSupportCheckPass running"); + if (node->GetAllOutDataAnchorsSize() != kInplaceSupportOutputNum) { + GELOGD("output num of node %s is not %u, skip InplaceSupportCheckPass", + node->GetName().c_str(), kInplaceSupportOutputNum); + return SUCCESS; + } + GE_CHECK_NOTNULL(node->GetOpDesc()); + const DataType &output_type = node->GetOpDesc()->GetOutputDesc(kInplaceSupportOutputIndex).GetDataType(); + const GeShape &output_shape = node->GetOpDesc()->GetOutputDesc(kInplaceSupportOutputIndex).GetShape(); + GELOGD("process InplaceSupportCheckPass on node %s", node->GetName().c_str()); + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_data_anchor == nullptr) { + continue; + } + auto in_node = peer_data_anchor->GetOwnerNode(); + if (kSrcNodeTypes.count(in_node->GetType()) > 0) { + GELOGD("meet src_node %s", in_node->GetName().c_str()); + continue; + } + if (peer_data_anchor->GetPeerInDataNodesSize() != kInplaceSupportOutputNum) { + GELOGD("peer_data_anchor links with multi in_data_anchors"); + continue; + } + + int32_t inplace_input_idx = in_data_anchor->GetIdx(); + const DataType &input_type = node->GetOpDesc()->GetInputDesc(inplace_input_idx).GetDataType(); + const GeShape &input_shape = node->GetOpDesc()->GetInputDesc(inplace_input_idx).GetShape(); + if (input_type != output_type) { + GELOGW("DataType mismatch, in_idx=%d, input_type=%u, output_type=%u", inplace_input_idx, input_type, output_type); + continue; + } + if (input_shape.GetDims() != output_shape.GetDims()) { + GELOGW("Shape mismatch, in_idx=%d, input_shape=[%s], output_shape=[%s]", + inplace_input_idx, input_shape.ToString().c_str(), output_shape.ToString().c_str()); + continue; + } + + GELOGD("add attr INPLACE_SUPPORT_INPUT_INDEX on node %s, input_idx=%d", node->GetName().c_str(), inplace_input_idx); + if (!AttrUtils::SetInt(node->GetOpDesc()->MutableOutputDesc(kInplaceSupportOutputIndex), + INPLACE_SUPPORT_INPUT_INDEX, inplace_input_idx)) { + GELOGE(FAILED, "Set attr INPLACE_SUPPORT_INPUT_INDEX on node %s failed.", node->GetName().c_str()); + return FAILED; + } + AddRePassNode(node); + break; + } + + GELOGD("InplaceSupportCheckPass success"); + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/inplace_support_check_pass.h b/ge/graph/passes/inplace_support_check_pass.h new file mode 100644 index 00000000..be2d6c75 --- /dev/null +++ b/ge/graph/passes/inplace_support_check_pass.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_INPLACE_SUPPORT_CHECK_PASS_H_ +#define GE_GRAPH_PASSES_INPLACE_SUPPORT_CHECK_PASS_H_ + +#include "graph/passes/base_pass.h" + +namespace ge { +class InplaceSupportCheckPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INPLACE_SUPPORT_CHECK_PASS_H_ diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index a7b922e0..392968e7 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -598,7 +598,7 @@ Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, cons /// Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, const std::set &same_cond_switch) { - GELOGD("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(), + GELOGD("ModifySwitchInCtlEdges: switch_node=%s, cast_node=%s", switch_node->GetName().c_str(), cast_node->GetName().c_str()); std::string orig_switch_name = switch_node->GetName(); OpDescPtr switch_desc = switch_node->GetOpDesc(); diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index fd70aee9..32f877cf 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -19,7 +19,6 @@ #include #include #include "common/formats/format_transfers/format_transfer_fractal_nz.h" -#include "common/formats/format_transfers/format_transfer_fractal_z.h" #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_transpose.h" @@ -38,7 +37,9 @@ #include "graph/passes/addn_pass.h" #include "graph/passes/aicpu_constant_folding_pass.h" #include "graph/passes/assert_pass.h" -#include "graph/passes/assign_pass.h" +#ifdef ONLY_COMPILE_OPEN_SRC +#include "graph/passes/assign_remove_pass.h" +#endif #include "graph/passes/common_subexpression_elimination_pass.h" #include "graph/passes/cond_pass.h" #include "graph/passes/cond_remove_pass.h" @@ -1698,7 +1699,9 @@ Status GraphPrepare::PrepareOptimize() { VarIsInitializedOpPass var_is_initialized_pass; ParallelConcatStartOpPass parallel_concat_start_op_pass; IdentityPass identity_pass(false); - AssignPass assign_pass; +#ifdef ONLY_COMPILE_OPEN_SRC + AssignRemovePass assign_remove_pass; +#endif SnapshotPass snapshot_pass; if (!options_.train_graph_flag) { names_to_passes.emplace_back("DropOutPass", &dropout_pass); @@ -1713,9 +1716,11 @@ Status GraphPrepare::PrepareOptimize() { names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); names_to_passes.emplace_back("ParallelConcatStartOpPass", ¶llel_concat_start_op_pass); names_to_passes.emplace_back("IdentityPass", &identity_pass); +#ifdef ONLY_COMPILE_OPEN_SRC if (GetContext().GetHostExecFlag()) { - names_to_passes.emplace_back("AssignPass", &assign_pass); + names_to_passes.emplace_back("AssignRemovePass", &assign_remove_pass); } +#endif GE_TIMESTAMP_START(names_to_passes); ret = ge_passes.Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::NamesToPasses"); diff --git a/ge/hybrid/common/npu_memory_allocator.cc b/ge/hybrid/common/npu_memory_allocator.cc index 2c38367a..c2602f37 100644 --- a/ge/hybrid/common/npu_memory_allocator.cc +++ b/ge/hybrid/common/npu_memory_allocator.cc @@ -20,6 +20,9 @@ #include "graph/manager/graph_caching_allocator.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/rdma_pool_allocator.h" +#ifndef ONLY_COMPILE_OPEN_SRC +#include "graph/manager/host_mem_allocator.h" +#endif namespace ge { namespace hybrid { @@ -64,7 +67,11 @@ void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { if (mem_type == RDMA_HBM) { buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(allocate_size, device_id_); } else if (mem_type == HOST_DDR) { +#ifndef ONLY_COMPILE_OPEN_SRC + buffer = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(allocate_size); +#else buffer = malloc(allocate_size); +#endif } else { if (allocate_size > kMaxHbmMemorySize) { GELOGE(PARAM_INVALID, "Invalid HBM memory size: %zu", allocate_size); @@ -101,7 +108,11 @@ void NpuMemoryAllocator::Deallocate(void *data, MemStorageType mem_type) { if (mem_type == RDMA_HBM) { MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Free(reinterpret_cast(data), device_id_); } else if (mem_type == HOST_DDR) { +#ifndef ONLY_COMPILE_OPEN_SRC + MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Free(data); +#else free(data); +#endif } else { MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(reinterpret_cast(data), device_id_); } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 46f8d452..46c9c39b 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -25,11 +25,13 @@ #include "graph/manager/graph_var_manager.h" #include "graph/manager/host_mem_manager.h" #include "graph/manager/trans_var_data_utils.h" +#ifndef ONLY_COMPILE_OPEN_SRC +#include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/host_mem_allocator.h" +#endif #include "graph/utils/graph_utils.h" #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/node_executor/node_executor.h" -#include "framework/common/debug/ge_log.h" -#include "graph/utils/attr_utils.h" namespace ge { namespace hybrid { @@ -852,9 +854,24 @@ Status HybridModelBuilder::InitConstantOps() { std::unique_ptr var_tensor; if (GetContext().GetHostExecFlag()) { +#ifndef ONLY_COMPILE_OPEN_SRC + GE_CHECK_NOTNULL(ge_tensor); + // Address for eigen kernel should be aligned with 16 bytes + // Tensors return by api GetWeights share data with proto, whose addr is not confirmed to be aligned + GeTensor aligned_tensor = ge_tensor->Clone(); + GELOGD("Init tensor with host constant %s size = %zu", var_name.c_str(), aligned_tensor.MutableData().GetSize()); + if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(aligned_tensor.GetAlignedPtr(), + aligned_tensor.GetData().size()) == nullptr) { + GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed."); + return MEMALLOC_FAILED; + } + var_tensor.reset(new(std::nothrow)TensorValue(aligned_tensor.MutableData().data(), + aligned_tensor.GetData().size())); +#else auto buffer = ge_tensor->MutableData(); GELOGD("Init tensor with host constant. size = %zu", buffer.GetSize()); var_tensor.reset(new(std::nothrow)TensorValue(buffer.GetData(), buffer.GetSize())); +#endif } else { GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); @@ -909,9 +926,21 @@ Status HybridModelBuilder::InitVariableTensors() { GELOGE(GE_GRAPH_MALLOC_FAILED, "Host variable [%s] malloc failed.", it.first.c_str()); return GE_GRAPH_MALLOC_FAILED; } +#ifndef ONLY_COMPILE_OPEN_SRC + if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(mem_info.host_aligned_ptr, + tensor_size) == nullptr) { + GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed."); + return MEMALLOC_FAILED; + } + GELOGD("Host variable [%s] malloc success, size=%lld.", it.first.c_str(), tensor_size); + + std::unique_ptr tensor(new (std::nothrow) TensorValue(mem_info.host_aligned_ptr->MutableGet(), + tensor_size)); +#else GELOGD("Host variable [%s] malloc success.", it.first.c_str()); std::unique_ptr tensor(new (std::nothrow) TensorValue(mem_info.host_address, tensor_size)); +#endif GE_CHECK_NOTNULL(tensor); hybrid_model_.variable_tensors_.emplace(it.first, std::move(tensor)); } diff --git a/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc b/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc index a61195b0..32522fe8 100755 --- a/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc +++ b/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc @@ -18,6 +18,10 @@ #include "hybrid/node_executor/host_cpu/kernel_factory.h" #include "graph/passes/folding_pass.h" #include "hybrid/model/hybrid_model.h" +#ifndef ONLY_COMPILE_OPEN_SRC +#include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/host_mem_allocator.h" +#endif #include "ge_local_engine/engine/host_cpu_engine.h" namespace ge { @@ -50,15 +54,23 @@ Status CpuKernelNodeTask::Execute(TaskContext &context) { auto input_desc_ptr = context.GetInputDesc(i); GE_CHECK_NOTNULL(input_desc_ptr); const auto &input_desc = *input_desc_ptr; +#ifndef ONLY_COMPILE_OPEN_SRC + auto tensor = context.GetInput(i); + GE_CHECK_NOTNULL(tensor); + auto item = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).GetAlignedPtr(tensor->GetData()); + GE_CHECK_NOTNULL(item.second); + auto in_tensor = MakeShared(input_desc, item.second, item.first); +#else GE_CHECK_NOTNULL(context.GetInput(i)); auto in_tensor = MakeShared(input_desc, reinterpret_cast(context.GetInput(i)->GetData()), context.GetInput(i)->GetSize()); +#endif GE_CHECK_NOTNULL(in_tensor); in_tensor->MutableTensorDesc().SetDataType(input_desc.GetDataType()); in_tensor->MutableTensorDesc().SetShape(input_desc.GetShape()); inputs.emplace_back(in_tensor); - GELOGI("node:%s allocate input %d, size=%zu", op_desc->GetName().c_str(), i, in_tensor->GetData().size()); + GELOGD("node:%s allocate input %d, size=%zu", op_desc->GetName().c_str(), i, in_tensor->GetData().size()); } std::vector outputs; @@ -72,14 +84,20 @@ Status CpuKernelNodeTask::Execute(TaskContext &context) { } auto tensor = context.GetOutput(i); GE_CHECK_NOTNULL(tensor); +#ifndef ONLY_COMPILE_OPEN_SRC + auto item = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).GetAlignedPtr(tensor->GetData()); + GE_CHECK_NOTNULL(item.second); + auto out_tensor = MakeShared(output_desc, item.second, item.first); +#else auto out_tensor = MakeShared(output_desc, reinterpret_cast(tensor->GetData()), tensor->GetSize()); +#endif GE_CHECK_NOTNULL(out_tensor); out_tensor->MutableTensorDesc().SetDataType(output_desc.GetDataType()); out_tensor->MutableTensorDesc().SetShape(output_desc.GetShape()); outputs.emplace_back(out_tensor); - GELOGI("node:%s allocate output %d, size=%zu", op_desc->GetName().c_str(), i, out_tensor->GetData().size()); + GELOGD("node:%s allocate output %d, size=%zu", op_desc->GetName().c_str(), i, out_tensor->GetData().size()); } return HostCpuEngine::GetInstance().Run(node_, inputs, outputs); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 8ccb6180..175774bb 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -224,7 +224,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/cond_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/for_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/enter_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/assign_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/assign_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc"