|
|
|
@ -27,6 +27,34 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
|
|
|
|
std::vector<AnfNodePtr> *plant_inputs) {
|
|
|
|
|
if (!AnfAlgo::IsTupleOutput(tuple_input)) {
|
|
|
|
|
auto abs = tuple_input->abstract();
|
|
|
|
|
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(plant_inputs);
|
|
|
|
|
auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
|
|
|
|
|
if (tuple_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
|
|
|
|
|
auto make_tuple = tuple_input->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
|
|
|
|
size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
|
|
|
|
|
for (size_t j = 0; j < tuple_input_num; ++j) {
|
|
|
|
|
// using for graph kernel
|
|
|
|
|
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
|
|
|
|
plant_inputs->emplace_back(dyn_input_node);
|
|
|
|
|
}
|
|
|
|
|
return input_size;
|
|
|
|
|
}
|
|
|
|
|
for (size_t index = 0; index < input_size; ++index) {
|
|
|
|
|
auto dyn_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
|
|
|
|
|
plant_inputs->emplace_back(dyn_input_node);
|
|
|
|
|
}
|
|
|
|
|
return input_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
@ -41,25 +69,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
|
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
|
auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
|
if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
|
|
|
|
|
auto input_size = AnfAlgo::GetOutputTensorNum(input_node);
|
|
|
|
|
dyn_input_sizes.push_back(input_size);
|
|
|
|
|
auto make_tuple = input_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
|
|
|
|
size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
|
|
|
|
|
for (size_t j = 0; j < tuple_input_num; ++j) {
|
|
|
|
|
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
|
|
|
|
if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
|
|
|
|
|
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>());
|
|
|
|
|
if (!success) {
|
|
|
|
|
MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
plant_inputs.push_back(dyn_input_node);
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::IsTupleOutput(input_node)) {
|
|
|
|
|
dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
|
|
|
|
|
} else {
|
|
|
|
|
dyn_input_sizes.push_back(-1);
|
|
|
|
|
plant_inputs.push_back(input_node);
|
|
|
|
|