|
|
|
@ -65,7 +65,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
|
|
|
|
|
return VisitKernel(node, 0);
|
|
|
|
|
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
|
|
|
|
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "the node tuple_get_item must have 2 inputs!";
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
|
|
|
|
}
|
|
|
|
|
auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input2);
|
|
|
|
@ -102,7 +102,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input0);
|
|
|
|
|
if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
|
|
|
|
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "the node tuple_get_item must have 2 inputs!";
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
|
|
|
|
}
|
|
|
|
|
auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input2);
|
|
|
|
@ -188,7 +188,7 @@ std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
|
|
|
|
|
void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "only cnode has attr,but this anf is " << node->DebugString();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
@ -204,7 +204,7 @@ void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::st
|
|
|
|
|
MS_EXCEPTION_IF_NULL(from);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(to);
|
|
|
|
|
if (!from->isa<CNode>() || !to->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "only cnode has attr,but this from_anf is " << from->DebugString() << " ,to_node is "
|
|
|
|
|
MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
|
|
|
|
|
<< to->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
|
|
|
|
@ -218,7 +218,7 @@ void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr
|
|
|
|
|
MS_EXCEPTION_IF_NULL(from);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(to);
|
|
|
|
|
if (!from->isa<CNode>() || !to->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "only cnode has attr,but this from_anf is " << from->DebugString() << ",to_node is "
|
|
|
|
|
MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
|
|
|
|
|
<< from->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
|
|
|
|
@ -231,7 +231,7 @@ void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr
|
|
|
|
|
void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "only cnode has attr,but this anf is " << node->DebugString();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
@ -241,7 +241,7 @@ void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr
|
|
|
|
|
bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(WARNING) << "only cnode has attr,but this anf is " << node->DebugString();
|
|
|
|
|
MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
|
|
|
@ -252,7 +252,7 @@ bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const AnfNodePtr &
|
|
|
|
|
size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "only cnode has real input,but this anf is " << node->DebugString();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
@ -404,7 +404,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
|
std::vector<kernel::Axis> result;
|
|
|
|
|
if (!build_info->GetInputReshapeType(input_idx, &result)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "filed to ge the node's[ " << node->DebugString() << "] reshape type !";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !";
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
@ -417,7 +417,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_info);
|
|
|
|
|
std::vector<kernel::Axis> result;
|
|
|
|
|
if (!build_info->GetOutputReshapeType(output_idx, &result)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "filed to ge the node's[ " << node->DebugString() << "] reshape type !";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !";
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
@ -593,7 +593,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
if (!kernel_info->SetOutputAddr(addr, output_idx)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node " << node->DebugString() << "set adr" << output_idx << " fail";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -603,7 +603,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node " << node->DebugString() << "set adr" << output_idx << " fail";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -614,7 +614,7 @@ DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, siz
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
auto addr = kernel_info->GetWorkspaceAddr(output_idx);
|
|
|
|
|
if (addr == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "output_idx " << output_idx << " of node " << node->DebugString()
|
|
|
|
|
MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
|
|
|
|
|
<< "] workspace addr is not exist";
|
|
|
|
|
}
|
|
|
|
|
return addr;
|
|
|
|
@ -625,7 +625,7 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &
|
|
|
|
|
const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (types.size() != shapes.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "types size " << types.size() << "should be same with shapes size " << shapes.size();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
|
|
|
|
|
}
|
|
|
|
|
if (shapes.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
|
|
|
|
@ -636,7 +636,7 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &
|
|
|
|
|
auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
|
|
|
|
|
node->set_abstract(abstract);
|
|
|
|
|
} else {
|
|
|
|
|
// mutiple output handle
|
|
|
|
|
// multiple output handle
|
|
|
|
|
std::vector<AbstractBasePtr> abstract_list;
|
|
|
|
|
for (size_t i = 0; i < types.size(); ++i) {
|
|
|
|
|
std::vector<int> shape_int;
|
|
|
|
@ -647,12 +647,12 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &
|
|
|
|
|
node->set_abstract(abstract_tuple);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// copy a abstract of a node to another node
|
|
|
|
|
// copy an abstract of a node to another node
|
|
|
|
|
void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
|
|
|
|
|
to_node->set_abstract(from_node->abstract());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// get KernelBuildType of node ,such as ATT,RT,FWK and so on
|
|
|
|
|
// get KernelBuildType of node, such as ATT,RT,FWK and so on
|
|
|
|
|
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto kernel_info = node->kernel_info();
|
|
|
|
@ -846,7 +846,7 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
|
|
|
|
|
auto find = spec_node_list.find(node_name);
|
|
|
|
|
if (find != spec_node_list.end()) {
|
|
|
|
|
ret = find->second[cur_index];
|
|
|
|
|
MS_LOG(INFO) << "real input index change to" << ret << ", node name:" << node_name;
|
|
|
|
|
MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return ret;
|
|
|
|
|