|
|
|
@ -112,6 +112,12 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
|
|
|
|
|
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
|
|
|
|
|
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
|
|
|
|
|
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0);
|
|
|
|
|
} else if (opt::IsNopNode(cnode)) {
|
|
|
|
|
if (cnode->inputs().size() == 2) {
|
|
|
|
|
return VisitKernelWithReturnType(cnode->input(1), 0);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
return std::make_pair(anf_node, index);
|
|
|
|
|
}
|
|
|
|
@ -299,20 +305,23 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
|
|
|
|
|
return build_info->GetInputFormat(input_idx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
|
|
|
|
|
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
if (!anf_node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
|
|
|
|
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
|
|
|
|
|
<< ".";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
|
|
|
|
}
|
|
|
|
|
auto node = cnode->input(input_idx + 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
|
|
|
|
|
return VisitKernel(node, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
|
|
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
|
|
|
|
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -346,18 +355,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
|
|
|
|
|
<< ".";
|
|
|
|
|
}
|
|
|
|
|
auto input_node = cnode->input(input_idx + 1);
|
|
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0);
|
|
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
|
|
|
|
|
return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -459,17 +457,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << node->DebugString() << "is not a CNode";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
|
|
|
|
}
|
|
|
|
|
auto input_node = cnode->input(input_idx + 1);
|
|
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0);
|
|
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
|
|
|
|
|
return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -492,17 +480,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
|
|
|
|
|
if (!anf_node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
|
|
|
|
}
|
|
|
|
|
auto node = cnode->input(input_idx + 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
|
|
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
|
|
|
|
return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -558,32 +536,12 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
|
|
|
|
|
if (!anf_node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf node is not a CNode";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
|
|
|
|
}
|
|
|
|
|
auto node = cnode->input(input_idx + 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
|
|
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
|
|
|
|
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
|
|
|
|
|
if (!anf_node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
|
|
|
|
}
|
|
|
|
|
auto node = cnode->input(input_idx + 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
|
|
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
|
|
|
|
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|