|
|
|
@ -40,6 +40,9 @@ using kernel::KernelBuildInfoPtr;
|
|
|
|
|
using kernel::KernelMod;
|
|
|
|
|
using kernel::KernelModPtr;
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr size_t kNopNodeInputSize = 2;
|
|
|
|
|
constexpr size_t kNopNodeRealInputIndex = 1;
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shape);
|
|
|
|
|
std::vector<size_t> shape_size_t;
|
|
|
|
@ -48,6 +51,26 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
|
|
|
|
if (tuple_get_item->size() != kTupleGetItemInputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
|
|
|
|
}
|
|
|
|
|
return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
|
|
|
|
if (tuple_get_item->size() != kTupleGetItemInputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
|
|
|
|
}
|
|
|
|
|
auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_index_value_node);
|
|
|
|
|
auto value_node = output_index_value_node->cast<ValueNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value_node);
|
|
|
|
|
return IntToSize(GetValue<int>(value_node->value()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
if (anf_node->isa<ValueNode>()) {
|
|
|
|
@ -83,49 +106,47 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
|
|
|
|
|
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, int index,
|
|
|
|
|
bool visit_nop_node,
|
|
|
|
|
const std::vector<PrimitivePtr> &return_types) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
for (const auto &prim_type : return_types) {
|
|
|
|
|
if (CheckPrimitiveType(anf_node, prim_type)) {
|
|
|
|
|
return std::make_pair(anf_node, index);
|
|
|
|
|
}
|
|
|
|
|
if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
|
|
|
|
|
return CheckPrimitiveType(anf_node, prim_type);
|
|
|
|
|
})) {
|
|
|
|
|
return KernelWithIndex(anf_node, index);
|
|
|
|
|
}
|
|
|
|
|
if (anf_node->isa<ValueNode>()) {
|
|
|
|
|
return std::make_pair(anf_node, 0);
|
|
|
|
|
} else if (anf_node->isa<Parameter>()) {
|
|
|
|
|
return std::make_pair(anf_node, 0);
|
|
|
|
|
} else if (anf_node->isa<CNode>()) {
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto input0 = cnode->input(0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input0);
|
|
|
|
|
if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
|
|
|
|
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
|
|
|
|
}
|
|
|
|
|
auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input2);
|
|
|
|
|
auto value_node = input2->cast<ValueNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value_node);
|
|
|
|
|
int item_idx = GetValue<int>(value_node->value());
|
|
|
|
|
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
|
|
|
|
|
visit_nop_node, return_types);
|
|
|
|
|
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
|
|
|
|
|
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
|
|
|
|
|
} else if (opt::IsNopNode(cnode) && visit_nop_node) {
|
|
|
|
|
if (cnode->inputs().size() == 2) {
|
|
|
|
|
return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
|
|
|
|
|
if (!anf_node->isa<CNode>()) {
|
|
|
|
|
return KernelWithIndex(anf_node, 0);
|
|
|
|
|
}
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
|
|
|
|
|
auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
|
|
|
|
|
GetTupleGetItemOutIndex(cnode), visit_nop_node, return_types);
|
|
|
|
|
if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
|
|
|
|
|
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
|
|
|
|
const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
|
|
|
|
|
size_t make_tuple_input_index = item_with_index_tmp.second + 1;
|
|
|
|
|
if (make_tuple_input_index >= make_tuple_inputs.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
|
|
|
|
|
<< "].";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
return std::make_pair(anf_node, index);
|
|
|
|
|
return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, visit_nop_node, return_types);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The input is invalid";
|
|
|
|
|
return item_with_index_tmp;
|
|
|
|
|
}
|
|
|
|
|
if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
|
|
|
|
|
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types);
|
|
|
|
|
}
|
|
|
|
|
if (opt::IsNopNode(cnode) && visit_nop_node) {
|
|
|
|
|
if (cnode->size() != kNopNodeInputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString();
|
|
|
|
|
}
|
|
|
|
|
return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, visit_nop_node, return_types);
|
|
|
|
|
}
|
|
|
|
|
return KernelWithIndex(anf_node, index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
|
|
|
|
@ -591,7 +612,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
|
|
|
|
|
if (opt::IsNopNode(node) && visit_nop_node) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->inputs().size() == 2) {
|
|
|
|
|
if (cnode->size() == kNopNodeInputSize) {
|
|
|
|
|
return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
|
|
|
|
@ -613,7 +634,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
|
|
|
|
|
if (opt::IsNopNode(node) && visit_nop_node) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->inputs().size() == 2) {
|
|
|
|
|
if (cnode->inputs().size() == kNopNodeInputSize) {
|
|
|
|
|
return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
|
|
|
|
@ -806,7 +827,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
|
|
|
|
|
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
|
|
|
|
|
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
|
|
|
|
|
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
|
|
|
|
|
IsPrimitive(input, prim::kPrimReturn);
|
|
|
|
|
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
|
|
|
|
|
return !is_virtual_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1117,5 +1138,14 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s
|
|
|
|
|
}
|
|
|
|
|
return GetCNodeOutputPrecision(kernel_with_index.first);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node->inputs().empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Illegal null input of cnode.";
|
|
|
|
|
}
|
|
|
|
|
auto input = node->input(kAnfPrimitiveIndex);
|
|
|
|
|
return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
|
|
|
|
|
}
|
|
|
|
|
} // namespace session
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|