Fix memory reallocation bug for communication ops

pull/14492/head
laiyongqiang 4 years ago
parent adc24f4263
commit 4526ce6845

@ -884,6 +884,14 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
return kernel_info->OutputAddrExist(output_idx); return kernel_info->OutputAddrExist(output_idx);
} }
bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->WorkspaceAddrExist(output_idx);
}
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
bool visit_nop_node) { bool visit_nop_node) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);

@ -153,6 +153,8 @@ class AnfRuntimeAlgorithm {
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
// check whether output addr is exist or not // check whether output addr is exist or not
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
// check whether the workspace addr exists or not
static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx);
// get address from prev node,input_index is the input index of current node related to prev node // get address from prev node,input_index is the input index of current node related to prev node
static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
bool visit_nop_node = true); bool visit_nop_node = true);

@ -81,6 +81,13 @@ DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const {
return workspace_address_list_[index]; return workspace_address_list_[index];
} }
// Returns true only when `index` is inside `workspace_address_list_` and the entry stored there is non-null.
bool KernelInfo::WorkspaceAddrExist(size_t index) const {
  const bool in_range = index < workspace_address_list_.size();
  // Short-circuit keeps the element access from running for an out-of-range index.
  return in_range && workspace_address_list_[index] != nullptr;
}
bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
if (workspace_address_list_.empty()) { if (workspace_address_list_.empty()) {
// parameter and valuenode // parameter and valuenode

@ -55,6 +55,7 @@ class KernelInfo : public KernelInfoDevice {
bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index);
DeviceAddress *GetWorkspaceAddr(size_t index) const; DeviceAddress *GetWorkspaceAddr(size_t index) const;
DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const; DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const;
bool WorkspaceAddrExist(size_t index) const;
bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index);
void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); void set_kernel_mod(const kernel::KernelModPtr &kernel_mod);
kernel::KernelMod *MutableKernelMod() const; kernel::KernelMod *MutableKernelMod() const;

@ -454,8 +454,8 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
std::vector<size_t> align_size_list; std::vector<size_t> align_size_list;
for (uint64_t mem_size : output_sizes) { for (uint64_t mem_size : output_sizes) {
if (AnfAlgo::OutputAddrExist(node, output_index++)) { if (AnfAlgo::OutputAddrExist(node, output_index++)) {
MS_LOG(INFO) << "communication op addr exist"; MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
continue; return;
} }
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
mem_size = mem_manager_->GetCommonAlignSize(mem_size); mem_size = mem_manager_->GetCommonAlignSize(mem_size);
@ -464,6 +464,10 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
align_size_list.emplace_back(mem_size); align_size_list.emplace_back(mem_size);
} }
if (align_size_list.empty()) {
return;
}
if (type == kReuseDynamicMem) { if (type == kReuseDynamicMem) {
// reuse communication op's all outputs' memory // reuse communication op's all outputs' memory
type = kReuseDynamicCommMem; type = kReuseDynamicCommMem;
@ -533,6 +537,10 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
for (size_t i = 0; i < input_num; ++i) { for (size_t i = 0; i < input_num; ++i) {
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
auto input_node = input_node_with_index.first; auto input_node = input_node_with_index.first;
if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
return;
}
DeviceAddressPtr address = nullptr; DeviceAddressPtr address = nullptr;
if (input_node->isa<CNode>()) { if (input_node->isa<CNode>()) {
address = PreAssignCNodeMemory(input_node, input_node_with_index.second); address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
@ -811,6 +819,10 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(kernel_mod);
size_t index = 0; size_t index = 0;
for (auto &size : kernel_mod->GetWorkspaceSizeList()) { for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
if (AnfAlgo::WorkspaceAddrExist(node, index)) {
MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
return;
}
auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size); auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
index++; index++;

Loading…
Cancel
Save