|
|
@ -30,7 +30,6 @@
|
|
|
|
#include "kernel/common_utils.h"
|
|
|
|
#include "kernel/common_utils.h"
|
|
|
|
#include "kernel/oplib/oplib.h"
|
|
|
|
#include "kernel/oplib/oplib.h"
|
|
|
|
#include "ir/value.h"
|
|
|
|
#include "ir/value.h"
|
|
|
|
#include "pre_activate/common/helper.h"
|
|
|
|
|
|
|
|
using mindspore::kernel::Address;
|
|
|
|
using mindspore::kernel::Address;
|
|
|
|
using mindspore::kernel::AddressPtr;
|
|
|
|
using mindspore::kernel::AddressPtr;
|
|
|
|
|
|
|
|
|
|
|
@ -637,7 +636,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel,
|
|
|
|
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
|
|
|
|
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
|
|
|
|
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
|
|
|
|
AddressPtrList *kernel_outputs) {
|
|
|
|
AddressPtrList *kernel_outputs) {
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel);
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel);
|
|
|
@ -649,15 +648,9 @@ void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const minds
|
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
|
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
|
|
|
|
return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
|
|
|
|
return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto is_all_nop_node = opt::IsAllNopNode(&graph);
|
|
|
|
|
|
|
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
|
|
|
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
|
|
|
|
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
|
|
|
|
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
|
|
|
|
DeviceAddressPtr device_address;
|
|
|
|
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
|
|
|
|
if (is_all_nop_node) {
|
|
|
|
|
|
|
|
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_address);
|
|
|
|
MS_EXCEPTION_IF_NULL(device_address);
|
|
|
|
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
|
|
|
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
|
|
|
MS_EXCEPTION_IF_NULL(input);
|
|
|
|
MS_EXCEPTION_IF_NULL(input);
|
|
|
@ -667,16 +660,8 @@ void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const minds
|
|
|
|
kernel_inputs->emplace_back(input);
|
|
|
|
kernel_inputs->emplace_back(input);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
|
|
|
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
|
|
|
auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
|
|
|
|
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
|
|
|
|
|
|
|
|
DeviceAddressPtr device_address;
|
|
|
|
|
|
|
|
if (is_all_nop_node) {
|
|
|
|
|
|
|
|
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_address);
|
|
|
|
|
|
|
|
kernel::AddressPtr output = std::make_shared<kernel::Address>();
|
|
|
|
kernel::AddressPtr output = std::make_shared<kernel::Address>();
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
output->addr = device_address->ptr_;
|
|
|
|
output->addr = device_address->ptr_;
|
|
|
@ -685,7 +670,7 @@ void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const minds
|
|
|
|
kernel_outputs->emplace_back(output);
|
|
|
|
kernel_outputs->emplace_back(output);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
|
|
|
|
for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
|
|
|
|
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
|
|
|
|
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
|
|
|
|
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
|
|
|
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
|
|
|
MS_EXCEPTION_IF_NULL(workspace);
|
|
|
|
MS_EXCEPTION_IF_NULL(workspace);
|
|
|
@ -740,7 +725,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
|
|
|
|
AddressPtrList kernel_inputs;
|
|
|
|
AddressPtrList kernel_inputs;
|
|
|
|
AddressPtrList kernel_workspaces;
|
|
|
|
AddressPtrList kernel_workspaces;
|
|
|
|
AddressPtrList kernel_outputs;
|
|
|
|
AddressPtrList kernel_outputs;
|
|
|
|
GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
|
|
|
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
|
|
|
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
|
|
|
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
|
|
|
if (!ret) {
|
|
|
|
if (!ret) {
|
|
|
|
MS_LOG(ERROR) << "Launch kernel failed.";
|
|
|
|
MS_LOG(ERROR) << "Launch kernel failed.";
|
|
|
|