@@ -27,86 +27,33 @@
namespace mindspore {
namespace opt {
namespace {
bool MakeValueNode(const AnfNodePtr &node) {
  auto value_node = node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    return false;
  }
  // create kernel_info for new value node
  auto kernel_info = std::make_shared<device::KernelInfo>();
  value_node->set_kernel_info(kernel_info);
  // create kernel_build_info for new value node
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  // set the format of value_node to DEFAULT_FORMAT
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  // set value node initial device data type = infer data type
  TypeId infer_data_type;
  if (AnfAlgo::GetOutputTensorNum(value_node) == 0) {
    infer_data_type = kTypeUnknown;
  } else {
    infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0);
  }
  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{infer_data_type});
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());
  return true;
}

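// Expand a tuple-output input: a tuple value node is split into one value node per element,
// any other node gets one TupleGetItem per element; the element count is appended to
// dyn_input_sizes and the new nodes to plant_inputs.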
void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node,
                                    std::vector<AnfNodePtr> *plant_inputs, std::vector<int> *dyn_input_sizes) {
  MS_EXCEPTION_IF_NULL(plant_inputs);
  MS_EXCEPTION_IF_NULL(dyn_input_sizes);
  MS_EXCEPTION_IF_NULL(graph);
  auto output_size = AnfAlgo::GetOutputTensorNum(input_node);
  dyn_input_sizes->push_back(output_size);
  std::vector<AnfNodePtr> convert_inputs;
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (input_node->isa<ValueNode>()) {
    auto value_node = input_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node);
  } else {
    for (size_t index = 0; index < output_size; ++index) {
      auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index);
      AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)},
                                          {AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get());
      convert_inputs.emplace_back(tuple_get_item);
    }
  }
  (void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs));
}

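// Flatten the tuple inputs of a cnode into a plain input list and record, per original input,
// the number of real inputs it expands to in dyn_input_sizes.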
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  MS_EXCEPTION_IF_NULL(graph);
-  auto &ori_args = cnode_ptr->inputs();
-  if (ori_args.size() < 1) {
-    return;
-  }
  std::vector<AnfNodePtr> plant_inputs;
  std::vector<int> dyn_input_sizes;
-  plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]);
-  for (size_t i = 1; i < ori_args.size(); ++i) {
-    auto input_node = ori_args[i];
-    if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
+  plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
+  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) {
+    auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
+    MS_EXCEPTION_IF_NULL(input_node);
+    if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
      auto input_size = AnfAlgo::GetOutputTensorNum(input_node);
      dyn_input_sizes.push_back(input_size);
-      auto cnode = input_node->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(cnode);
-      auto inputs = cnode->inputs();
-      for (size_t j = 1; j < inputs.size(); ++j) {
-        MS_EXCEPTION_IF_NULL(inputs[j]);
-        if (IsValueNode<tensor::Tensor>(inputs[j])) {
-          auto success = MakeValueNode(inputs[j]);
+      auto make_tuple = input_node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(make_tuple);
+      for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) {
+        auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
+        MS_EXCEPTION_IF_NULL(dyn_input_node);
+        if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
+          auto kernel_graph = graph->cast<KernelGraphPtr>();
+          MS_EXCEPTION_IF_NULL(kernel_graph);
+          auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>());
          if (!success) {
-            MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString();
+            MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString();
          }
        }
-        plant_inputs.push_back(inputs[j]);
+        plant_inputs.push_back(dyn_input_node);
      }
    } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) {
      ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes);
    } else {
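      // a plain (non-tuple) input is marked with -1 in dyn_input_sizes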
      dyn_input_sizes.push_back(-1);
      plant_inputs.push_back(input_node);
@@ -139,9 +86,8 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
    for (auto &t : todos) {
      ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>());
    }
-  } else {
-    ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
  }
+  ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
  return node;
}
} // namespace opt