|
|
|
@ -338,7 +338,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
|
|
|
|
|
if (!AnfAlgo::IsRealKernel(node)) {
|
|
|
|
|
return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
@ -360,7 +360,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
|
|
|
|
|
if (!IsRealKernel(node)) {
|
|
|
|
|
GetPrevNodeOutputFormat(node, input_idx);
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
@ -467,7 +467,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
|
|
|
|
|
if (!IsRealKernel(node)) {
|
|
|
|
|
return GetPrevNodeOutputReshapeType(node, input_idx);
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
@ -486,7 +486,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
|
|
|
|
|
if (!IsRealKernel(node)) {
|
|
|
|
|
return GetPrevNodeOutputReshapeType(node, output_idx);
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
@ -546,7 +546,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
|
|
|
|
|
if (!IsRealKernel(node)) {
|
|
|
|
|
return GetPrevNodeOutputDeviceDataType(node, output_idx);
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
@ -567,7 +567,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
|
|
|
|
|
if (!IsRealKernel(node)) {
|
|
|
|
|
return GetPrevNodeOutputDeviceDataType(node, 0);
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
@ -597,7 +597,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
|
|
|
|
|
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto addr = kernel_info->GetOutputAddr(output_idx);
|
|
|
|
|
if (addr == nullptr) {
|
|
|
|
@ -619,7 +619,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
|
|
|
|
|
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto addr = kernel_info->GetMutableOutputAddr(output_idx);
|
|
|
|
|
if (addr == nullptr) {
|
|
|
|
@ -636,7 +636,7 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
|
|
|
|
|
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
|
|
|
|
|
<< GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
return kernel_info->OutputAddrExist(output_idx);
|
|
|
|
|
}
|
|
|
|
@ -656,7 +656,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode
|
|
|
|
|
// set output device addr of anf_node
|
|
|
|
|
void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
if (!kernel_info->SetOutputAddr(addr, output_idx)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
|
|
|
|
@ -666,7 +666,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out
|
|
|
|
|
// set workspace device addr of anf_node
|
|
|
|
|
void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
|
|
|
|
@ -676,7 +676,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t
|
|
|
|
|
// get workspace device addr of anf_node
|
|
|
|
|
DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto addr = kernel_info->GetWorkspaceAddr(output_idx);
|
|
|
|
|
if (addr == nullptr) {
|
|
|
|
@ -720,7 +720,7 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_
|
|
|
|
|
|
|
|
|
|
kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
// select_kernel_build_info() has checked whether return pointer is null
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
@ -731,7 +731,7 @@ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
|
|
|
|
|
// get KernelBuildType of node, such as ATT,RT,FWK and so on
|
|
|
|
|
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
// select_kernel_build_info() has checked whether return pointer is null
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
@ -741,7 +741,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
|
|
|
|
|
|
|
|
|
|
kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
@ -750,7 +750,7 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
|
|
|
|
|
|
|
|
|
|
kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto build_info = kernel_info->select_kernel_build_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
@ -760,7 +760,7 @@ kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
|
|
|
|
|
// set select kernel_build_info
|
|
|
|
|
void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
|
|
|
|
|
}
|
|
|
|
@ -768,7 +768,7 @@ void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &sel
|
|
|
|
|
// get select kernel_build_info
|
|
|
|
|
KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
return kernel_info->GetMutableSelectKernelBuildInfo();
|
|
|
|
|
}
|
|
|
|
@ -776,7 +776,7 @@ KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePt
|
|
|
|
|
// get kernelMode
|
|
|
|
|
KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
return kernel_info->MutableKernelMod();
|
|
|
|
|
}
|
|
|
|
@ -784,7 +784,7 @@ KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
|
|
|
|
|
// set kernel mod
|
|
|
|
|
void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
kernel_info->set_kernel_mod(kernel_mod);
|
|
|
|
|
}
|
|
|
|
@ -850,42 +850,42 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
|
|
|
|
|
|
|
|
|
|
void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
kernel_info->set_stream_id(stream_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
return kernel_info->stream_id();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
kernel_info->set_stream_distinction_label(stream_label);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
return kernel_info->stream_distinction_label();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
kernel_info->set_graph_id(graph_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
return kernel_info->graph_id();
|
|
|
|
|
}
|
|
|
|
@ -913,7 +913,7 @@ bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
|
|
|
|
|
if (node->isa<ValueNode>()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
return kernel_info->is_feature_map();
|
|
|
|
|
}
|
|
|
|
|