!2857 Revert "optimize the graph output of all nop node"

Merge pull request !2857 from limingqi107/master
pull/2857/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d0dd892884

@ -30,7 +30,6 @@
#include "kernel/common_utils.h" #include "kernel/common_utils.h"
#include "kernel/oplib/oplib.h" #include "kernel/oplib/oplib.h"
#include "ir/value.h" #include "ir/value.h"
#include "pre_activate/common/helper.h"
using mindspore::kernel::Address; using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
@ -637,7 +636,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
} }
} }
void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel, void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
AddressPtrList *kernel_outputs) { AddressPtrList *kernel_outputs) {
MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel);
@ -649,15 +648,9 @@ void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const minds
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs); return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
} }
auto is_all_nop_node = opt::IsAllNopNode(&graph);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
DeviceAddressPtr device_address; auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
if (is_all_nop_node) {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false);
} else {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true);
}
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>(); kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
@ -667,16 +660,8 @@ void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const minds
kernel_inputs->emplace_back(input); kernel_inputs->emplace_back(input);
} }
auto kernel_mod = AnfAlgo::GetKernelMod(kernel); for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
MS_EXCEPTION_IF_NULL(kernel_mod); auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
} else {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true);
}
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr output = std::make_shared<kernel::Address>(); kernel::AddressPtr output = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
output->addr = device_address->ptr_; output->addr = device_address->ptr_;
@ -685,7 +670,7 @@ void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const minds
kernel_outputs->emplace_back(output); kernel_outputs->emplace_back(output);
} }
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace); MS_EXCEPTION_IF_NULL(workspace);
@ -740,7 +725,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList kernel_inputs; AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces; AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs; AddressPtrList kernel_outputs;
GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) { if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed."; MS_LOG(ERROR) << "Launch kernel failed.";

@ -96,8 +96,8 @@ class KernelRuntime {
private: private:
void AssignStaticMemoryOutput(session::KernelGraph *graph); void AssignStaticMemoryOutput(session::KernelGraph *graph);
void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
bool LaunchKernelMod(const session::KernelGraph &graph); bool LaunchKernelMod(const session::KernelGraph &graph);
void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);

@ -81,15 +81,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
} }
} }
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
DeviceAddressPtr address; auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
auto is_all_nop_node = opt::IsAllNopNode(&graph);
if (is_all_nop_node) {
// The graph does not remove the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
} else {
// The graph removes the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, true);
}
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
auto shape = AnfAlgo::GetOutputInferShape(node, output_index); auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
TypeId type_id = kNumberTypeFloat32; TypeId type_id = kNumberTypeFloat32;

Loading…
Cancel
Save